-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspeaker_diarization.py
More file actions
163 lines (138 loc) · 6.49 KB
/
speaker_diarization.py
File metadata and controls
163 lines (138 loc) · 6.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json
import os
from datetime import timedelta
from pyannote.audio import Pipeline
import torch
import sys
import logging
import argparse
from dotenv import load_dotenv
# Load environment variables from the .env file so that settings such as
# HF_TOKEN (read with os.getenv in get_auth_token) are available.
load_dotenv()
# Hex display color assigned to each known speaker label. Labels that are
# missing from this table are given a fallback color at lookup time.
SPEAKER_COLORS = dict(
    SPEAKER_00="#FF5733",  # Orange-red
    SPEAKER_01="#33FF57",  # Light green
    SPEAKER_02="#33FFF6",  # Cyan
    SPEAKER_03="#FF33F6",  # Pink
    SPEAKER_04="#3357FF",  # Blue
)
def timedelta_to_srt_timestamp(td: timedelta) -> str:
    """Convert a timedelta to an SRT-style timestamp (HH:MM:SS,mmm).

    The conversion uses integer arithmetic on the timedelta itself instead of
    float seconds, so the millisecond field is not disturbed by binary
    floating-point error (1.001 s yields ``,001``, not ``,000``).
    Sub-millisecond precision is truncated. Hours are not wrapped at 24, so a
    duration of a day or more renders as e.g. ``25:00:00,000``.

    Args:
        td: Offset to format.

    Returns:
        The formatted timestamp. A negative ``td`` is formatted from its
        absolute value and prefixed with ``-``.
    """
    sign = "-" if td < timedelta(0) else ""
    # timedelta // timedelta is exact integer division (no float round-trip).
    total_ms = abs(td) // timedelta(milliseconds=1)
    total_sec, milliseconds = divmod(total_ms, 1000)
    total_min, seconds = divmod(total_sec, 60)
    hours, minutes = divmod(total_min, 60)
    return f"{sign}{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
def get_auth_token() -> str:
    """Return the HuggingFace auth token from the HF_TOKEN environment variable.

    Returns:
        The non-empty value of ``HF_TOKEN``.

    Raises:
        SystemExit: With exit code 1 when ``HF_TOKEN`` is unset or empty,
            after logging instructions for providing a token.
    """
    # Fetch the logger here rather than relying on a module global, so this
    # function also works when the module is imported instead of run as a
    # script. logging.getLogger returns the same object for the same name.
    log = logging.getLogger(__name__)

    token = os.getenv("HF_TOKEN")
    if token:
        return token

    log.error("HuggingFace authentication token not found!")
    log.error("\nTo use speaker diarization, you need to:")
    log.error("1. Create a .env file in the project root")
    log.error("2. Add your HuggingFace token to the .env file:")
    log.error("\nHF_TOKEN=your_token_here\n")
    log.error("3. Make sure you've accepted the model license at:")
    log.error(" https://huggingface.co/pyannote/speaker-diarization-3.1")
    sys.exit(1)
def process_audio(audio_path, output_json):
    """Run speaker diarization on an audio file and write the segments as JSON.

    Loads the ``pyannote/speaker-diarization-3.1`` pipeline (on CUDA when
    ``torch.cuda.is_available()``, otherwise on CPU), runs it on
    ``audio_path`` and writes a JSON list to ``output_json``. Each list entry
    is a dict with the keys ``start`` and ``end`` (SRT-style timestamps),
    ``speaker`` (the diarization label) and ``color`` (hex code looked up in
    SPEAKER_COLORS, ``#FFFFFF`` for labels without an entry).

    Args:
        audio_path: Path of the audio file passed to the pipeline.
        output_json: Path of the JSON file to create or overwrite.

    Returns:
        True when the segments were written; False when diarization or
        writing the file raised an exception (the error is logged).

    Raises:
        SystemExit: With exit code 1 when the model cannot be loaded, or from
            get_auth_token when no token is configured.
    """
    # Fetch the logger here rather than relying on a module global, so this
    # function also works when the module is imported instead of run as a
    # script.
    log = logging.getLogger(__name__)

    # Get authentication token
    auth_token = get_auth_token()

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Using device: {device}")

    try:
        # Initialize the pipeline with the pretrained model
        log.info("Loading diarization model...")
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=auth_token,
        ).to(device)
    except Exception as e:
        log.error(f"\nError loading the model: {str(e)}")
        log.error("\nThis might be because:")
        log.error("1. You haven't accepted the model's license agreement")
        log.error("2. Your token doesn't have the necessary permissions")
        log.error("3. The model is temporarily unavailable")
        log.error("\nPlease visit https://huggingface.co/pyannote/speaker-diarization-3.1")
        log.error("and make sure you've accepted the license agreement.")
        sys.exit(1)

    log.info("Model loaded successfully. Processing audio...")

    try:
        # Perform diarization
        log.debug("Running diarization pipeline...")
        diarization = pipeline(audio_path)

        # Process speaker segments
        speaker_data = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            start_td = timedelta(seconds=turn.start)
            end_td = timedelta(seconds=turn.end)
            # Keep the label exactly as emitted. Rebuilding it from its last
            # two characters would fold e.g. SPEAKER_100 into SPEAKER_00 and
            # merge two distinct speakers.
            speaker_label = speaker
            # Default to white if the label has no color mapping
            color = SPEAKER_COLORS.get(speaker_label, "#FFFFFF")
            speaker_data.append({
                "start": timedelta_to_srt_timestamp(start_td),
                "end": timedelta_to_srt_timestamp(end_td),
                "speaker": speaker_label,
                "color": color,
            })
            log.debug(f"Processed segment: {speaker_label} {start_td} -> {end_td}")

        # Write results to JSON
        with open(output_json, "w", encoding="utf-8") as f:
            json.dump(speaker_data, f, indent=4)

        distinct_speakers = {segment["speaker"] for segment in speaker_data}
        log.info(f"Successfully wrote speaker diarization data to {output_json}")
        log.info(f"Found {len(distinct_speakers)} distinct speakers")
        return True
    except Exception as e:
        log.error(f"\nError processing audio: {str(e)}")
        return False
if __name__ == "__main__":
    # Script entry point: parse the command line, configure logging, resolve
    # the fixed input/output file names inside the chosen directory, and run
    # the diarization unless an up-to-date result already exists.

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description=(
            "Perform speaker diarization on an audio file to identify different speakers.\n"
            "\n"
            "This script processes rec16.wav (16kHz mono WAV file) and identifies different\n"
            "speakers throughout the audio, creating timestamps for each speaker segment.\n"
            "\n"
            "Input:\n"
            "  - rec16.wav: Audio file in the specified output directory\n"
            "\n"
            "Output:\n"
            "  - speaker_segments.json: JSON file containing:\n"
            "      * Timestamps for each speech segment\n"
            "      * Speaker identification (SPEAKER_00, SPEAKER_01, etc.)\n"
            "      * Color codes for visualizing different speakers\n"
            "\n"
            "The script uses the pyannote.audio library and requires a Hugging Face API token\n"
            "(HF_TOKEN) in the .env file. You must also accept the model's license agreement at:\n"
            "https://huggingface.co/pyannote/speaker-diarization-3.1\n"
        ),
        # The default formatter re-wraps the description into one paragraph;
        # the raw formatter keeps the line breaks and list layout above.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging')
    parser.add_argument('-o', '--outputdir', default='output',
                        help='Directory for the input audio file and the output speaker segments file')
    parser.add_argument('-r', '--regenerate', action='store_true',
                        help='Force regeneration of the speaker segments')
    args = parser.parse_args()

    # Configure logging
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=log_level, format='%(levelname)s: %(message)s')
    # Bound at module level (this block is not inside a function), so the name
    # is visible as a global to the rest of this file.
    logger = logging.getLogger(__name__)

    # Setup input/output paths
    os.makedirs(args.outputdir, exist_ok=True)
    audio_path = os.path.join(args.outputdir, 'rec16.wav')
    output_json = os.path.join(args.outputdir, 'speaker_segments.json')

    # Check if input file exists
    if not os.path.exists(audio_path):
        logger.error(f"Input audio file not found: {audio_path}")
        sys.exit(1)

    # Check if output file exists and handle regeneration
    if os.path.exists(output_json):
        if not args.regenerate:
            logger.info(f"Speaker segments file already exists at {output_json}. Use --regenerate to process again.")
            sys.exit(0)
        logger.info("Removing existing speaker segments file for regeneration")
        os.remove(output_json)

    # Process the audio; the exit status reflects success or failure
    success = process_audio(audio_path, output_json)
    sys.exit(0 if success else 1)