-
-
Notifications
You must be signed in to change notification settings - Fork 120
Expand file tree
/
Copy pathdatabase.py
More file actions
1571 lines (1413 loc) · 59.1 KB
/
Copy pathdatabase.py
File metadata and controls
1571 lines (1413 loc) · 59.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# AudioMuse-AI - https://github.com/NeptuneHub/AudioMuse-AI
# Copyright (C) 2025 NeptuneHub
# SPDX-License-Identifier: AGPL-3.0-only
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License v3.0. See the LICENSE file
# in the project root or <https://github.com/NeptuneHub/AudioMuse-AI/blob/main/LICENSE>
"""Postgres data-access layer for the whole application.
Owns the per-request connection (via Flask ``g``), the embedded-server
lifecycle, the ``init_db`` schema bootstrap, and every read/write helper for
tasks, track analysis and embeddings, projections, and alchemy anchors/radios.
Main Features:
* Connection management plus ``init_db`` table/index creation and migrations.
* Task-status and history persistence with sanitized fields and capped history rows.
* Embedding, projection, and alchemy CRUD helpers shared by workers and the web app.
"""
import json
import logging
import sys
import time
import numpy as np
import psycopg2
from flask import g
from psycopg2.extras import DictCursor
import config
logger = logging.getLogger(__name__)
from tz_helper import UTC_NOW_SQL
from sanitization import sanitize_db_field
from config import (
TASK_STATUS_PENDING,
TASK_STATUS_STARTED,
TASK_STATUS_PROGRESS,
TASK_STATUS_SUCCESS,
TASK_STATUS_FAILURE,
TASK_STATUS_REVOKED,
)
# Number of rows kept in task_history: record_task_history deletes every row
# except the most recent ones (by recorded_at, id) after each insert.
TASK_HISTORY_MAX_ROWS = 10
# Number of trailing details['log'] entries that _normalize_task_details keeps
# for tasks whose status is not SUCCESS.
MAX_LOG_ENTRIES_STORED = 10
# In-memory cache filled by load_map_projection: None until a projection is
# loaded, then a dict with the keys 'index_name', 'id_map' and 'projection'.
MAP_PROJECTION_CACHE = None
# Object returned by pgserver.get_server(); assigned by start_embedded /
# ensure_embedded_running and reset to None by stop_embedded.
_embedded_server = None
def get_db():
    """Return the psycopg2 connection stored on Flask's ``g``, opening it on first use.

    The connection is created from ``config.DATABASE_URL`` with a 30 second
    connect timeout, TCP keepalive settings and the session option
    ``-c statement_timeout=600000``.

    Returns:
        The connection object stored as ``g.db``.

    Raises:
        psycopg2.OperationalError: Logged and re-raised when connecting fails.
    """
    if 'db' in g:
        return g.db
    connect_options = {
        'connect_timeout': 30,
        'keepalives_idle': 600,
        'keepalives_interval': 30,
        'keepalives_count': 3,
        'options': '-c statement_timeout=600000',
    }
    try:
        g.db = psycopg2.connect(config.DATABASE_URL, **connect_options)
    except psycopg2.OperationalError:
        logger.exception("Failed to connect to database")
        raise
    return g.db
def close_db(e=None):
    """Remove the connection stored on Flask's ``g`` and close it, if there is one.

    Args:
        e: Ignored by this function (presumably the error object Flask hands
            to teardown callbacks -- confirm where this is registered).
    """
    connection = g.pop('db', None)
    if connection is None:
        return
    connection.close()
def start_embedded(data_dir):
    """Obtain the embedded Postgres server for ``data_dir`` and remember it.

    The server object returned by ``pgserver.get_server`` is stored in the
    module-level ``_embedded_server``.

    Args:
        data_dir: Data directory passed straight to ``pgserver.get_server``.

    Returns:
        The value of the server's ``get_uri()``.
    """
    global _embedded_server
    # Imported at call time, not at module import.
    import pgserver

    server = pgserver.get_server(data_dir)
    _embedded_server = server
    return server.get_uri()
def ensure_embedded_running(data_dir):
    """Return the embedded server URI, re-acquiring the server if one was already started.

    When no server has been started yet this simply delegates to
    ``start_embedded``. Otherwise the entry for ``data_dir`` is removed from
    pgserver's private ``PostgresServer._instances`` mapping (best effort) and
    ``pgserver.get_server`` is called again.

    Args:
        data_dir: Data directory of the embedded server.

    Returns:
        The value of the server's ``get_uri()``.
    """
    global _embedded_server
    if _embedded_server is None:
        return start_embedded(data_dir)

    import pgserver
    from pathlib import Path

    try:
        # NOTE(review): presumably evicting the cached instance makes
        # get_server build a fresh handle instead of returning a stale one --
        # confirm against pgserver. Failures here are deliberately ignored.
        registry_key = Path(data_dir).expanduser().resolve()
        pgserver.PostgresServer._instances.pop(registry_key, None)
    except Exception:
        pass

    server = pgserver.get_server(data_dir)
    _embedded_server = server
    return server.get_uri()
def stop_embedded():
    """Call ``cleanup()`` on the remembered embedded server and forget it.

    Does nothing when no server is currently remembered.
    """
    global _embedded_server
    if _embedded_server is None:
        return
    _embedded_server.cleanup()
    _embedded_server = None
def _build_task_note(task_type, details_obj, db):
if not isinstance(details_obj, dict):
details_obj = {}
t = (task_type or '').lower()
try:
if 'analysis' in t:
try:
with db.cursor() as cur:
cur.execute(
"SELECT details FROM task_status WHERE parent_task_id = %s AND status = 'SUCCESS'",
(details_obj.get('_task_id') or '',),
)
rows = cur.fetchall()
except Exception:
rows = []
songs = 0
for (d,) in rows or []:
if not d:
continue
try:
obj = json.loads(d)
if isinstance(obj, dict):
v = obj.get('tracks_analyzed')
if isinstance(v, (int, float)):
songs += int(v)
except Exception:
continue
if songs > 0:
return f"Songs analyzed: {songs}"
albums = details_obj.get('albums_completed') or details_obj.get(
'total_albums_processed'
)
if albums:
return f"Albums analyzed: {albums}"
return ''
if 'clean' in t:
for k in (
'tracks_deleted',
'orphans_removed',
'songs_cleaned',
'tracks_removed',
'deleted_count',
'cleaned_tracks',
):
v = details_obj.get(k)
if isinstance(v, (int, float)):
return f"Songs cleaned: {int(v)}"
return ''
if 'cluster' in t:
sampled = (
(details_obj.get('best_params') or {}).get('initial_subset_size')
if isinstance(details_obj.get('best_params'), dict)
else None
)
if sampled is None:
sampled = details_obj.get('sampled_songs') or details_obj.get('num_sampled_songs')
n_clusters = details_obj.get('num_playlists_created') or details_obj.get('num_clusters')
parts = []
if sampled:
parts.append(f"sampled: {int(sampled)}")
if n_clusters:
parts.append(f"clusters: {int(n_clusters)}")
return ' | '.join(parts)
except Exception as e:
logger.debug(f"task note builder failed for type={task_type}: {e}")
return ''
def record_task_history(task_id, task_type, status, duration_seconds=None, note=None, details=None):
    """Insert one row into ``task_history`` for ``task_id`` and trim the table.

    Nothing is written when ``task_id`` is falsy or a row for it already
    exists. When ``note`` is ``None`` it is derived with ``_build_task_note``
    and, failing that, from ``details['status_message']`` or
    ``details['message']``. After the insert, all rows except the newest
    ``TASK_HISTORY_MAX_ROWS`` are deleted. Any failure is logged as a warning
    and followed by a best-effort rollback; it is never raised.

    Args:
        task_id: Identifier of the task.
        task_type: Task type stored with the row.
        status: Status stored with the row.
        duration_seconds: Optional duration stored with the row.
        note: Optional explicit note; derived when ``None``.
        details: Optional details dict used to derive the note.
    """
    if not task_id:
        return
    try:
        db = get_db()
        if note is None:
            # Work on a copy so the caller's dict does not gain '_task_id'.
            details_obj = dict(details) if isinstance(details, dict) else {}
            details_obj['_task_id'] = task_id
            note = _build_task_note(task_type, details_obj, db) or ''
            if not note:
                note = details_obj.get('status_message') or details_obj.get('message') or ''
        with db.cursor() as cur:
            cur.execute("SELECT 1 FROM task_history WHERE task_id = %s LIMIT 1", (task_id,))
            already_recorded = cur.fetchone()
            if already_recorded:
                return
            history_row = (task_id, task_type, status, duration_seconds, note)
            cur.execute(
                f"""
INSERT INTO task_history (task_id, task_type, status, duration_seconds, note, recorded_at)
VALUES (%s, %s, %s, %s, %s, {UTC_NOW_SQL})
""",
                history_row,
            )
            # Keep only the newest rows.
            cur.execute(
                """
DELETE FROM task_history
WHERE id NOT IN (
SELECT id FROM task_history ORDER BY recorded_at DESC, id DESC LIMIT %s
)
""",
                (TASK_HISTORY_MAX_ROWS,),
            )
        db.commit()
    except Exception as e:
        logger.warning(f"record_task_history failed for {task_id}: {e}")
        try:
            db.rollback()
        except Exception:
            pass
def _normalize_task_details(details, status):
if not isinstance(details, dict):
return
if status == TASK_STATUS_SUCCESS:
details.pop('log_storage_info', None)
if not isinstance(details.get('log'), list) or not details.get('log'):
details['log'] = ["Task completed successfully."]
return
if not isinstance(details.get('log'), list):
return
log_list = details['log']
if len(log_list) <= MAX_LOG_ENTRIES_STORED:
details.pop('log_storage_info', None)
return
original_log_length = len(log_list)
details['log'] = log_list[-MAX_LOG_ENTRIES_STORED:]
details['log_storage_info'] = (
f"Log in DB truncated to last {MAX_LOG_ENTRIES_STORED} entries. Original length: {original_log_length}."
)
def _maybe_record_task_history(db, task_id, task_type, status, parent_task_id, details, current_unix_time):
if parent_task_id is not None:
return
if status not in (TASK_STATUS_SUCCESS, TASK_STATUS_FAILURE, TASK_STATUS_REVOKED):
return
if not task_type or task_type == 'unknown':
return
duration_s = None
try:
with db.cursor() as hist_cur:
hist_cur.execute(
"SELECT start_time, end_time FROM task_status WHERE task_id = %s",
(task_id,),
)
row = hist_cur.fetchone()
if row and row[0] is not None:
end = row[1] if row[1] is not None else current_unix_time
duration_s = max(0.0, float(end) - float(row[0]))
except Exception:
pass
record_task_history(task_id, task_type, status, duration_s, details=details)
def save_task_status(
    task_id,
    task_type,
    status=TASK_STATUS_PENDING,
    parent_task_id=None,
    sub_type_identifier=None,
    progress=0,
    details=None,
):
    """Insert or update the ``task_status`` row for ``task_id``.

    ``details`` is normalized in place by ``_normalize_task_details`` and
    stored as JSON. On conflict the existing row keeps its ``task_type`` and
    its ``start_time``; ``end_time`` is set only the first time one of the
    statuses 'SUCCESS', 'FAILURE' or 'REVOKED' is saved. Database errors are
    logged and rolled back rather than raised. Afterwards a task-history row is
    recorded when ``_maybe_record_task_history`` decides one is due.

    Args:
        task_id: Identifier of the task.
        task_type: Task type name.
        status: New status.
        parent_task_id: Optional parent task id.
        sub_type_identifier: Optional sub-type identifier.
        progress: Progress value stored with the row.
        details: Optional details dict (mutated by normalization).
    """
    db = get_db()
    current_unix_time = time.time()
    details_json = None
    if details is not None:
        _normalize_task_details(details, status)
        details_json = json.dumps(details)
    upsert_sql = """
INSERT INTO task_status (task_id, parent_task_id, task_type, sub_type_identifier, status, progress, details, timestamp, start_time, end_time)
VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), %s, CASE WHEN %s IN ('SUCCESS', 'FAILURE', 'REVOKED') THEN %s ELSE NULL END)
ON CONFLICT (task_id) DO UPDATE SET
status = EXCLUDED.status,
parent_task_id = EXCLUDED.parent_task_id,
sub_type_identifier = EXCLUDED.sub_type_identifier,
progress = EXCLUDED.progress,
details = EXCLUDED.details,
timestamp = NOW(),
start_time = COALESCE(task_status.start_time, %s),
end_time = CASE
WHEN EXCLUDED.status IN ('SUCCESS', 'FAILURE', 'REVOKED') AND task_status.end_time IS NULL
THEN %s
ELSE task_status.end_time
END
"""
    upsert_params = (
        task_id,
        parent_task_id,
        task_type,
        sub_type_identifier,
        status,
        progress,
        details_json,
        current_unix_time,  # start_time for a new row
        status,
        current_unix_time,  # end_time for a new row in a final status
        current_unix_time,  # start_time fallback on update
        current_unix_time,  # end_time on update
    )
    cur = db.cursor()
    try:
        cur.execute(upsert_sql, upsert_params)
        db.commit()
    except psycopg2.Error:
        logger.exception(f"DB Error saving task status for {task_id}")
        try:
            db.rollback()
            logger.info(f"DB transaction rolled back for task status update of {task_id}.")
        except psycopg2.Error:
            logger.exception(f"DB Error during rollback for task status {task_id}")
    finally:
        cur.close()
    # History recording is best effort and must never fail the status save.
    try:
        _maybe_record_task_history(
            db, task_id, task_type, status, parent_task_id, details, current_unix_time
        )
    except Exception as e_hist:
        logger.debug(f"history record skipped for {task_id}: {e_hist}")
def get_task_info_from_db(task_id):
    """Fetch the ``task_status`` row for ``task_id`` as a plain dict.

    The returned dict also carries ``running_time_seconds``: ``0.0`` when the
    row has no ``start_time``, otherwise the non-negative difference between
    ``end_time`` (or the current time when the task has no end time yet) and
    ``start_time``.

    Args:
        task_id: Identifier of the task.

    Returns:
        The row as a dict, or ``None`` when no such task exists.
    """
    db = get_db()
    # The ``with`` block closes the cursor even when the query raises.
    with db.cursor(cursor_factory=DictCursor) as cur:
        cur.execute(
            """
SELECT
task_id, parent_task_id, task_type, sub_type_identifier, status, progress, details, timestamp, start_time, end_time
FROM task_status
WHERE task_id = %s
""",
            (task_id,),
        )
        row = cur.fetchone()
    if not row:
        return None
    row_dict = dict(row)
    start_time = row_dict.get('start_time')
    if start_time is None:
        row_dict['running_time_seconds'] = 0.0
        return row_dict
    end_time = row_dict.get('end_time')
    effective_end_time = end_time if end_time is not None else time.time()
    # 0.0 (not 0) keeps the value a float in every branch.
    row_dict['running_time_seconds'] = max(0.0, effective_end_time - start_time)
    return row_dict
def get_score_data_by_ids(item_ids_list):
    """Fetch the ``score`` rows whose ``item_id`` is in ``item_ids_list``.

    Query errors are logged and an empty list is returned. When the failed
    statement left the connection's transaction in the error state, the
    transaction is rolled back so later queries on the same connection are
    not rejected.

    Args:
        item_ids_list: Item ids to look up; an empty or ``None`` value
            short-circuits to ``[]`` without touching the database.

    Returns:
        A list of row dicts (possibly empty).
    """
    if not item_ids_list:
        return []
    conn = get_db()
    cur = conn.cursor(cursor_factory=DictCursor)
    query = """
SELECT s.item_id, s.title, s.author, s.album, s.album_artist, s.tempo, s.key, s.scale, s.mood_vector, s.energy, s.other_features, s.year, s.rating, s.file_path
FROM score s
WHERE s.item_id IN %s
"""
    try:
        cur.execute(query, (tuple(item_ids_list),))
        rows = cur.fetchall()
    except Exception:
        logger.exception("Error fetching score data by IDs")
        rows = []
        # Only roll back an aborted transaction: work that is still pending
        # in a healthy transaction must not be discarded by a read helper.
        try:
            in_error = psycopg2.extensions.TRANSACTION_STATUS_INERROR
            if conn.get_transaction_status() == in_error:
                conn.rollback()
        except Exception:
            logger.debug("Rollback after failed score lookup did not succeed", exc_info=True)
    finally:
        cur.close()
    return [dict(row) for row in rows]
def get_tracks_by_ids(item_ids_list):
    """Fetch ``score`` rows joined with their embedding for the given item ids.

    Ids are converted to ``str`` before querying. Each returned dict carries
    an extra ``embedding_vector`` key: the stored bytes viewed as ``float32``,
    or an empty array when the track has no embedding.

    Args:
        item_ids_list: Item ids to look up; an empty or ``None`` value
            short-circuits to ``[]`` without touching the database.

    Returns:
        A list of row dicts (possibly empty).
    """
    if not item_ids_list:
        return []
    conn = get_db()
    item_ids_str = tuple(str(item_id) for item_id in item_ids_list)
    query = """
SELECT s.item_id, s.title, s.author, s.album, s.album_artist, s.tempo, s.key, s.scale, s.mood_vector, s.energy, s.other_features, s.year, s.rating, s.file_path, e.embedding
FROM score s
LEFT JOIN embedding e ON s.item_id = e.item_id
WHERE s.item_id IN %s
"""
    # The ``with`` block closes the cursor even when the query raises.
    with conn.cursor(cursor_factory=DictCursor) as cur:
        cur.execute(query, (item_ids_str,))
        rows = cur.fetchall()
    processed_rows = []
    for row in rows:
        row_dict = dict(row)
        blob = row_dict.get('embedding')
        if blob:
            row_dict['embedding_vector'] = np.frombuffer(blob, dtype=np.float32)
        else:
            row_dict['embedding_vector'] = np.array([])
        processed_rows.append(row_dict)
    return processed_rows
def load_map_projection(index_name, force_reload=False):
    """Load the map projection stored under ``index_name`` and cache it in memory.

    A cached projection with the same name is returned directly unless
    ``force_reload`` is set. Otherwise the row named ``index_name`` is read
    from ``map_projection_data``; when there is none, rows named
    ``<index_name>_<part>_<total>`` are collected, checked for a consistent and
    complete ``total``, ordered by ``part`` and reassembled. The bytes are
    viewed as ``float32`` and reshaped to two columns when the element count is
    even.

    Args:
        index_name: Name of the projection to load.
        force_reload: Bypass the in-memory cache when true.

    Returns:
        ``(id_map, projection)``, or ``(None, None)`` when the projection is
        missing, its segments are inconsistent or incomplete, or loading fails.
    """
    global MAP_PROJECTION_CACHE
    cached = MAP_PROJECTION_CACHE
    if not force_reload and cached and cached.get('index_name') == index_name:
        logger.info(f"Map projection '{index_name}' already loaded in cache. Skipping reload.")
        return cached.get('id_map'), cached.get('projection')

    logger.info(f"Attempting to load map projection '{index_name}' from database into memory...")
    conn = get_db()
    cur = conn.cursor()
    try:
        cur.execute(
            "SELECT projection_data, id_map_json FROM map_projection_data WHERE index_name = %s",
            (index_name,),
        )
        whole_row = cur.fetchone()
        if whole_row and whole_row[0] is not None:
            proj_blob, id_map_json = whole_row[0], whole_row[1]
        else:
            # Fallback: the projection may be stored as numbered segments.
            import re
            from tasks.index_build_helpers import reassemble_segmented_id_map

            # Underscores are LIKE wildcards, so they are escaped in the name.
            like_pattern = index_name.replace('_', r'\_') + r"\_%\_%"
            cur.execute(
                "SELECT index_name, projection_data, id_map_json FROM map_projection_data WHERE index_name LIKE %s ESCAPE '\\'",
                (like_pattern,),
            )
            candidates = cur.fetchall()
            if not candidates:
                logger.warning(
                    f"Map projection '{index_name}' not found in the database. Cache will be empty."
                )
                return None, None

            segment_name_re = re.compile(rf"^{re.escape(index_name)}_(\d+)_(\d+)$")
            parts = []
            total_expected = None
            for row_name, segment_blob, segment_id_map in candidates:
                matched = segment_name_re.match(row_name)
                if matched is None:
                    # LIKE is looser than the exact segment naming scheme.
                    continue
                segment_no = int(matched.group(1))
                total = int(matched.group(2))
                if total_expected is None:
                    total_expected = total
                elif total_expected != total:
                    logger.error(
                        f"Map projection segment total mismatch for '{index_name}' ({total_expected} vs {total}). Aborting load."
                    )
                    return None, None
                parts.append((segment_no, segment_blob, segment_id_map))

            if total_expected is None or len(parts) != total_expected:
                logger.error(
                    f"Incomplete map projection segments for '{index_name}': expected {total_expected}, found {len(parts)}. Aborting load."
                )
                return None, None

            parts.sort(key=lambda segment: segment[0])
            proj_blob = b"".join(bytes(segment[1]) for segment in parts if segment[1])
            id_map_json = reassemble_segmented_id_map((segment[0], segment[2]) for segment in parts)

        proj = np.frombuffer(proj_blob, dtype=np.float32)
        if proj.size % 2 == 0:
            proj = proj.reshape((-1, 2))
        id_map = json.loads(id_map_json)
        MAP_PROJECTION_CACHE = {'index_name': index_name, 'id_map': id_map, 'projection': proj}
        logger.info(
            f"Map projection '{index_name}' with {len(id_map)} items loaded successfully into memory."
        )
        return id_map, proj
    except Exception:
        logger.exception("Failed to load map projection")
        return None, None
    finally:
        cur.close()
def _valid_year(year_value):
if 1000 <= year_value <= 2100:
return year_value
return None
def _parse_year_parts(parts):
try:
if len(parts[0]) == 4:
result = _valid_year(int(parts[0]))
if result is not None:
return result
if len(parts[2]) == 4:
result = _valid_year(int(parts[2]))
if result is not None:
return result
if len(parts[2]) == 2:
year = int(parts[2])
year += 2000 if year < 30 else 1900
return _valid_year(year)
except (ValueError, TypeError, IndexError):
pass
return None
def _parse_year_from_date(year_value):
if year_value is None:
return None
year_str = str(year_value).strip()
if not year_str:
return None
try:
result = _valid_year(int(year_str))
if result is not None:
return result
except (ValueError, TypeError):
pass
parts = year_str.replace('/', '-').split('-')
if len(parts) == 3:
return _parse_year_parts(parts)
return None
def _clamp_rating(rating):
if rating is None:
return None
try:
rating = int(rating)
if rating < 0 or rating > 5:
return None
return rating
except (ValueError, TypeError):
return None
def save_track_analysis_and_embedding(
    item_id,
    title,
    author,
    tempo,
    key,
    scale,
    moods,
    embedding_vector,
    energy=None,
    other_features=None,
    album=None,
    album_artist=None,
    year=None,
    rating=None,
    file_path=None,
):
    """Upsert a track's analysis into ``score`` and its embedding into ``embedding``.

    Text fields go through ``sanitize_db_field``, ``year`` through
    ``_parse_year_from_date`` and ``rating`` through ``_clamp_rating``.
    ``moods`` is stored as comma-separated ``name:value`` pairs with three
    decimals. The embedding row is written only when ``embedding_vector`` is a
    non-empty NumPy array (stored as ``float32`` bytes). Both writes share one
    transaction.

    Raises:
        Exception: Any failure is rolled back, logged and re-raised.
    """
    title = sanitize_db_field(title, max_length=500, field_name="title")
    author = sanitize_db_field(author, max_length=200, field_name="author")
    album = sanitize_db_field(album, max_length=200, field_name="album")
    album_artist = sanitize_db_field(album_artist, max_length=200, field_name="album_artist")
    key = sanitize_db_field(key, max_length=10, field_name="key")
    scale = sanitize_db_field(scale, max_length=10, field_name="scale")
    other_features = sanitize_db_field(other_features, max_length=2000, field_name="other_features")
    year = _parse_year_from_date(year)
    rating = _clamp_rating(rating)
    file_path = sanitize_db_field(file_path, max_length=1000, field_name="file_path")
    mood_str = ','.join(f"{k}:{v:.3f}" for k, v in moods.items())

    score_upsert_sql = """
INSERT INTO score (item_id, title, author, tempo, key, scale, mood_vector, energy, other_features, album, album_artist, year, rating, file_path)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (item_id) DO UPDATE SET
title = EXCLUDED.title,
author = EXCLUDED.author,
tempo = EXCLUDED.tempo,
key = EXCLUDED.key,
scale = EXCLUDED.scale,
mood_vector = EXCLUDED.mood_vector,
energy = EXCLUDED.energy,
other_features = EXCLUDED.other_features,
album = EXCLUDED.album,
album_artist = EXCLUDED.album_artist,
year = EXCLUDED.year,
rating = EXCLUDED.rating,
file_path = EXCLUDED.file_path
"""
    # Order must match the column list of the INSERT above.
    score_values = (
        item_id,
        title,
        author,
        tempo,
        key,
        scale,
        mood_str,
        energy,
        other_features,
        album,
        album_artist,
        year,
        rating,
        file_path,
    )
    embedding_upsert_sql = """
INSERT INTO embedding (item_id, embedding) VALUES (%s, %s)
ON CONFLICT (item_id) DO UPDATE SET embedding = EXCLUDED.embedding
"""

    conn = get_db()
    cur = conn.cursor()
    try:
        cur.execute(score_upsert_sql, score_values)
        has_embedding = isinstance(embedding_vector, np.ndarray) and embedding_vector.size > 0
        if has_embedding:
            embedding_blob = embedding_vector.astype(np.float32).tobytes()
            cur.execute(embedding_upsert_sql, (item_id, psycopg2.Binary(embedding_blob)))
        conn.commit()
    except Exception:
        conn.rollback()
        logger.exception("Error saving track analysis and embedding for %s", item_id)
        raise
    finally:
        cur.close()
def save_clap_embedding(item_id, clap_embedding_vector):
    """Upsert the CLAP embedding of ``item_id`` as ``float32`` bytes.

    Nothing is written when the vector is ``None`` or an empty NumPy array.

    Args:
        item_id: Identifier of the track.
        clap_embedding_vector: Embedding exposing ``astype`` (a NumPy array).

    Raises:
        Exception: Any failure is rolled back, logged and re-raised.
    """
    if clap_embedding_vector is None:
        return
    if isinstance(clap_embedding_vector, np.ndarray) and clap_embedding_vector.size == 0:
        return
    upsert_sql = """
INSERT INTO clap_embedding (item_id, embedding) VALUES (%s, %s)
ON CONFLICT (item_id) DO UPDATE SET embedding = EXCLUDED.embedding
"""
    conn = get_db()
    cur = conn.cursor()
    try:
        embedding_blob = clap_embedding_vector.astype(np.float32).tobytes()
        cur.execute(upsert_sql, (item_id, psycopg2.Binary(embedding_blob)))
        conn.commit()
    except Exception:
        conn.rollback()
        logger.exception(f"Error saving CLAP embedding for {item_id}")
        raise
    finally:
        cur.close()
def get_clap_embedding(item_id):
    """Load the CLAP embedding of ``item_id`` as a ``float32`` array.

    Query errors are logged and ``None`` is returned. When the failed
    statement left the connection's transaction in the error state, the
    transaction is rolled back so later queries on the same connection are
    not rejected.

    Args:
        item_id: Identifier of the track.

    Returns:
        The embedding array, or ``None`` when there is no stored embedding or
        loading fails.
    """
    conn = get_db()
    cur = conn.cursor()
    try:
        cur.execute("SELECT embedding FROM clap_embedding WHERE item_id = %s", (item_id,))
        row = cur.fetchone()
        if row and row[0]:
            return np.frombuffer(row[0], dtype=np.float32)
        return None
    except Exception:
        logger.exception(f"Error loading CLAP embedding for {item_id}")
        # Only roll back an aborted transaction: work that is still pending
        # in a healthy transaction must not be discarded by a read helper.
        try:
            in_error = psycopg2.extensions.TRANSACTION_STATUS_INERROR
            if conn.get_transaction_status() == in_error:
                conn.rollback()
        except Exception:
            logger.debug("Rollback after failed CLAP embedding lookup did not succeed", exc_info=True)
        return None
    finally:
        cur.close()
def save_lyrics_embedding(item_id, lyrics_embedding_vector, axis_vector=None):
    """Upsert the lyrics embedding (and optional axis vector) of ``item_id``.

    Both vectors are stored as ``float32`` bytes; non-array inputs are
    converted with ``np.asarray``. Nothing is written when the embedding is
    ``None`` or an empty NumPy array. An absent or empty ``axis_vector`` is
    stored as NULL. On conflict ``updated_at`` is refreshed.

    Args:
        item_id: Identifier of the track.
        lyrics_embedding_vector: Embedding as a NumPy array or array-like.
        axis_vector: Optional axis vector as a NumPy array or array-like.

    Raises:
        Exception: Any failure is rolled back, logged and re-raised.
    """
    if lyrics_embedding_vector is None:
        return
    if isinstance(lyrics_embedding_vector, np.ndarray) and lyrics_embedding_vector.size == 0:
        return
    upsert_sql = """
INSERT INTO lyrics_embedding (item_id, embedding, axis_vector) VALUES (%s, %s, %s)
ON CONFLICT (item_id) DO UPDATE SET embedding = EXCLUDED.embedding, axis_vector = EXCLUDED.axis_vector, updated_at = CURRENT_TIMESTAMP
"""
    conn = get_db()
    cur = conn.cursor()
    try:
        if isinstance(lyrics_embedding_vector, np.ndarray):
            embedding_blob = lyrics_embedding_vector.astype(np.float32).tobytes()
        else:
            embedding_blob = np.asarray(lyrics_embedding_vector, dtype=np.float32).tobytes()

        axis_param = None
        if axis_vector is not None:
            if isinstance(axis_vector, np.ndarray):
                axis_array = axis_vector
            else:
                axis_array = np.asarray(axis_vector, dtype=np.float32)
            if axis_array.size > 0:
                axis_bytes = axis_array.astype(np.float32, copy=False).tobytes()
                axis_param = psycopg2.Binary(axis_bytes)

        cur.execute(upsert_sql, (item_id, psycopg2.Binary(embedding_blob), axis_param))
        conn.commit()
    except Exception:
        conn.rollback()
        logger.exception(f"Error saving lyrics embedding for {item_id}")
        raise
    finally:
        cur.close()
# NOTE(review): not referenced in the visible part of this file; presumably the
# artist-projection counterpart of MAP_PROJECTION_CACHE -- confirm at its use site.
ARTIST_PROJECTION_CACHE = None
def init_db():
db = get_db()
with db.cursor() as cur:
cur.execute("SELECT pg_advisory_lock(726354821)")
try:
if sys.platform == 'win32':
for ext in ('unaccent', 'pg_trgm'):
cur.execute("SAVEPOINT ext_create")
try:
cur.execute(f'CREATE EXTENSION IF NOT EXISTS {ext}')
cur.execute("RELEASE SAVEPOINT ext_create")
except Exception:
logger.warning("Extension %s not available -- skipping", ext)
cur.execute("ROLLBACK TO SAVEPOINT ext_create")
else:
cur.execute('CREATE EXTENSION IF NOT EXISTS unaccent')
cur.execute('CREATE EXTENSION IF NOT EXISTS pg_trgm')
cur.execute(
"CREATE TABLE IF NOT EXISTS score (item_id TEXT PRIMARY KEY, title TEXT, author TEXT, album TEXT, album_artist TEXT, tempo REAL, key TEXT, scale TEXT, mood_vector TEXT)"
)
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'energy')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'energy' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN energy REAL")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'other_features')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'other_features' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN other_features TEXT")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'album')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'album' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN album TEXT")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'album_artist')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'album_artist' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN album_artist TEXT")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'year')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'year' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN year INTEGER")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'rating')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'rating' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN rating INTEGER")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'file_path')"
)
if not cur.fetchone()[0]:
logger.info("Adding 'file_path' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN file_path TEXT")
cur.execute(
"SELECT is_generated FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'search_u'"
)
row = cur.fetchone()
search_u_generated = row and row[0] == 'ALWAYS'
if search_u_generated:
logger.info(
"Dropping legacy generated 'search_u' column to replace it with a trigger-updated column."
)
cur.execute("ALTER TABLE score DROP COLUMN IF EXISTS search_u")
row = None
if not row:
logger.info("Adding 'search_u' column to 'score' table.")
cur.execute("ALTER TABLE score ADD COLUMN search_u TEXT")
if sys.platform == 'win32':
cur.execute("SAVEPOINT search_setup")
try:
cur.execute(
"CREATE OR REPLACE FUNCTION immutable_unaccent(text) RETURNS text LANGUAGE sql IMMUTABLE AS $$ SELECT public.unaccent($1) $$;"
)
cur.execute("""
CREATE OR REPLACE FUNCTION score_search_u_sync() RETURNS trigger LANGUAGE plpgsql AS $$
BEGIN
NEW.search_u := lower(immutable_unaccent(concat_ws(' ', NEW.title, NEW.author, NEW.album)));
RETURN NEW;
END;
$$;
""")
cur.execute("DROP TRIGGER IF EXISTS score_search_u_sync_trigger ON score")
cur.execute("""
CREATE TRIGGER score_search_u_sync_trigger
BEFORE INSERT OR UPDATE ON score
FOR EACH ROW
EXECUTE FUNCTION score_search_u_sync();
""")
cur.execute(
"UPDATE score SET search_u = lower(immutable_unaccent(concat_ws(' ', title, author, album))) WHERE search_u IS NULL"
)
cur.execute(
"CREATE INDEX IF NOT EXISTS score_search_u_trgm ON score USING gin (search_u gin_trgm_ops)"
)
cur.execute("RELEASE SAVEPOINT search_setup")
except Exception:
logger.warning(
"unaccent/pg_trgm extensions not available -- accent-insensitive search disabled"
)
cur.execute("ROLLBACK TO SAVEPOINT search_setup")
else:
cur.execute(
"CREATE OR REPLACE FUNCTION immutable_unaccent(text) RETURNS text LANGUAGE sql IMMUTABLE AS $$ SELECT public.unaccent($1) $$;"
)
cur.execute("""
CREATE OR REPLACE FUNCTION score_search_u_sync() RETURNS trigger LANGUAGE plpgsql AS $$
BEGIN
NEW.search_u := lower(immutable_unaccent(concat_ws(' ', NEW.title, NEW.author, NEW.album)));
RETURN NEW;
END;
$$;
""")
cur.execute("DROP TRIGGER IF EXISTS score_search_u_sync_trigger ON score")
cur.execute("""
CREATE TRIGGER score_search_u_sync_trigger
BEFORE INSERT OR UPDATE ON score
FOR EACH ROW
EXECUTE FUNCTION score_search_u_sync();
""")
cur.execute(
"UPDATE score SET search_u = lower(immutable_unaccent(concat_ws(' ', title, author, album))) WHERE search_u IS NULL"
)
cur.execute(
"CREATE INDEX IF NOT EXISTS score_search_u_trgm ON score USING gin (search_u gin_trgm_ops)"
)
cur.execute(
"CREATE INDEX IF NOT EXISTS idx_score_album_artist_album ON score (album_artist, album)"
)
cur.execute("CREATE INDEX IF NOT EXISTS idx_score_author ON score (author)")
cur.execute(
"CREATE TABLE IF NOT EXISTS playlist (id SERIAL PRIMARY KEY, playlist_name TEXT, item_id TEXT, title TEXT, author TEXT, UNIQUE (playlist_name, item_id))"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS task_status (id SERIAL PRIMARY KEY, task_id TEXT UNIQUE NOT NULL, parent_task_id TEXT, task_type TEXT NOT NULL, sub_type_identifier TEXT, status TEXT, progress INTEGER DEFAULT 0, details TEXT, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"CREATE INDEX IF NOT EXISTS idx_task_status_parent ON task_status (parent_task_id)"
)
for col_name in ['start_time', 'end_time']:
cur.execute(
"SELECT data_type FROM information_schema.columns WHERE table_name = 'task_status' AND column_name = %s",
(col_name,),
)
if not cur.fetchone():
cur.execute(f"ALTER TABLE task_status ADD COLUMN {col_name} DOUBLE PRECISION")
cur.execute("""
CREATE TABLE IF NOT EXISTS task_history (
id SERIAL PRIMARY KEY,
recorded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
task_id TEXT,
task_type TEXT,
status TEXT,
duration_seconds DOUBLE PRECISION,
note TEXT
)
""")
cur.execute(
"CREATE TABLE IF NOT EXISTS embedding (item_id TEXT PRIMARY KEY, FOREIGN KEY (item_id) REFERENCES score (item_id) ON DELETE CASCADE)"
)
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'embedding' AND column_name = 'embedding')"
)
if not cur.fetchone()[0]:
cur.execute("ALTER TABLE embedding ADD COLUMN embedding BYTEA")
cur.execute(
"CREATE TABLE IF NOT EXISTS lyrics_embedding (item_id TEXT PRIMARY KEY, FOREIGN KEY (item_id) REFERENCES score (item_id) ON DELETE CASCADE)"
)
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'lyrics_embedding' AND column_name = 'embedding')"
)
if not cur.fetchone()[0]:
cur.execute("ALTER TABLE lyrics_embedding ADD COLUMN embedding BYTEA")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'lyrics_embedding' AND column_name = 'axis_vector')"
)
if not cur.fetchone()[0]:
cur.execute("ALTER TABLE lyrics_embedding ADD COLUMN axis_vector BYTEA")
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'lyrics_embedding' AND column_name = 'updated_at')"
)
if not cur.fetchone()[0]:
cur.execute(
"ALTER TABLE lyrics_embedding ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS clap_embedding (item_id TEXT PRIMARY KEY, FOREIGN KEY (item_id) REFERENCES score (item_id) ON DELETE CASCADE)"
)
cur.execute(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'clap_embedding' AND column_name = 'embedding')"
)
if not cur.fetchone()[0]:
cur.execute("ALTER TABLE clap_embedding ADD COLUMN embedding BYTEA")
cur.execute("DROP TABLE IF EXISTS voyager_index_data")
cur.execute("DROP TABLE IF EXISTS clap_index_data")
cur.execute("DROP TABLE IF EXISTS lyrics_index_data")
cur.execute("DROP TABLE IF EXISTS lyrics_axes_index_data")
cur.execute("DROP TABLE IF EXISTS artist_index_data")
cur.execute(
"CREATE TABLE IF NOT EXISTS artist_metadata_data (name VARCHAR(255) PRIMARY KEY, blob_data BYTEA NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS ivf_dir (name VARCHAR(255) PRIMARY KEY, blob_data BYTEA NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS ivf_cell (index_name VARCHAR(255) NOT NULL, cell_id INTEGER NOT NULL, cell_data BYTEA NOT NULL, PRIMARY KEY (index_name, cell_id))"
)
cur.execute("ALTER TABLE ivf_cell ALTER COLUMN cell_data SET STORAGE EXTERNAL")
cur.execute(
"CREATE TABLE IF NOT EXISTS map_projection_data (index_name VARCHAR(255) PRIMARY KEY, projection_data BYTEA NOT NULL, id_map_json TEXT NOT NULL, embedding_dimension INTEGER NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS artist_component_projection (index_name VARCHAR(255) PRIMARY KEY, projection_data BYTEA NOT NULL, artist_component_map_json TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS cron (id SERIAL PRIMARY KEY, name TEXT, task_type TEXT NOT NULL, cron_expr TEXT NOT NULL, enabled BOOLEAN DEFAULT FALSE, last_run DOUBLE PRECISION, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS audiomuse_users (id SERIAL PRIMARY KEY, username TEXT UNIQUE NOT NULL, password_hash TEXT NOT NULL, role TEXT NOT NULL DEFAULT 'user', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
)
cur.execute(
"ALTER TABLE audiomuse_users ADD COLUMN IF NOT EXISTS role TEXT NOT NULL DEFAULT 'user'"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS dashboard_stats ("
"id INTEGER PRIMARY KEY, "
"updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, "
"content JSONB NOT NULL DEFAULT '{}'::jsonb, "
"indexes JSONB NOT NULL DEFAULT '[]'::jsonb, "
"CONSTRAINT dashboard_stats_singleton CHECK (id = 1))"
)
cur.execute(
"SELECT COUNT(*) FROM information_schema.table_constraints "
"WHERE table_name = 'dashboard_stats' AND constraint_type = 'PRIMARY KEY'"
)
row = cur.fetchone()
if row and row[0] == 0:
logger.info(
"Cleaning dashboard_stats and adding missing primary key constraint to dashboard_stats.id"
)
cur.execute("DELETE FROM dashboard_stats")
cur.execute(
"ALTER TABLE dashboard_stats ADD CONSTRAINT dashboard_stats_pkey PRIMARY KEY (id)"
)
cur.execute(
"CREATE TABLE IF NOT EXISTS artist_mapping (artist_name TEXT PRIMARY KEY, artist_id TEXT)"
)
cur.execute(