File size: 12,018 Bytes
696ecfb
a78e699
 
696ecfb
a78e699
696ecfb
 
a78e699
 
180ad9e
a37304e
180ad9e
 
 
 
ab19864
0888364
ddf32d4
696ecfb
1bc4ac3
 
696ecfb
a78e699
696ecfb
62a1cfd
 
 
 
 
 
 
7d3f4c9
 
 
 
 
 
 
 
3d9b565
7d3f4c9
 
3d9b565
 
 
 
 
7d3f4c9
 
3d9b565
 
7d3f4c9
3d9b565
7d3f4c9
3d9b565
7d3f4c9
 
3d9b565
7d3f4c9
 
 
388bc44
 
e236407
 
 
 
 
 
 
 
 
 
388bc44
 
e236407
 
 
 
 
 
388bc44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc87bd2
 
 
e236407
dc87bd2
 
e236407
388bc44
180ad9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a37304e
 
 
 
 
 
 
 
 
 
180ad9e
 
 
 
 
a37304e
180ad9e
 
 
1bc4ac3
 
180ad9e
 
 
 
 
 
a37304e
180ad9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a37304e
d2436f4
a37304e
 
 
180ad9e
a37304e
 
d2436f4
a37304e
180ad9e
 
a37304e
180ad9e
a37304e
180ad9e
 
a37304e
 
 
180ad9e
 
a37304e
 
251a198
 
 
 
 
 
 
1bc4ac3
 
 
 
 
 
 
 
696ecfb
1bc4ac3
 
a78e699
251a198
e236407
1bc4ac3
 
 
 
 
 
 
 
696ecfb
1bc4ac3
 
 
 
 
251a198
1bc4ac3
 
 
388bc44
dc87bd2
388bc44
1bc4ac3
c5b8c89
 
7d3f4c9
d2436f4
dc87bd2
1bc4ac3
dc87bd2
c5b8c89
1bc4ac3
d2436f4
 
 
 
 
 
 
1bc4ac3
 
 
 
 
 
 
 
658ff7a
251a198
 
 
 
1bc4ac3
 
 
 
696ecfb
dc87bd2
 
 
 
 
696ecfb
7d3f4c9
1bc4ac3
 
 
 
 
 
 
 
696ecfb
 
 
 
 
 
 
1bc4ac3
 
a78e699
 
 
 
 
 
1bc4ac3
 
 
 
 
 
 
 
 
a37304e
 
1bc4ac3
 
 
 
 
 
 
 
 
a37304e
 
 
 
 
 
 
 
 
1bc4ac3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import streamlit as st
import torch
import torchaudio
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
import tempfile
import os
import matplotlib.pyplot as plt
from pyannote.core import notebook
from huggingface_hub import HfApi, snapshot_download, hf_hub_download
from huggingface_hub.errors import LocalEntryNotFoundError, HfHubHTTPError
import requests
import pyannote.audio
import sys
import traceback
from speechbrain.pretrained import EncoderClassifier
from pydub import AudioSegment
import numpy as np

# Page setup: wide layout, with the same title in the browser tab and the page header.
st.set_page_config(page_title="Optimized Speaker Diarization App", layout="wide")
st.title("Optimized Speaker Diarization App")

# The Hugging Face access token comes from the environment (e.g. a Space secret).
# Nothing below can work without it, so halt the script run early when it is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    st.error("HF_TOKEN not found in environment variables. Please set it in your Hugging Face Space secrets.")
    st.stop()



class ProgressHook:
    """Streamlit progress reporter that can be passed as a pipeline ``hook``.

    Shows the current processing stage in a status container and the fraction
    of work completed in a progress bar.

    Supported call shapes:

    * ``hook(stage, artifact, **kwargs)`` -- starts a new stage; if ``completed``
      and ``total`` are also passed as keyword arguments, progress is updated too.
    * ``hook(completed=..., total=...)`` -- progress update only.
    * ``hook(completed, total)`` -- progress update given as two numbers.

    Note: this definition shadows the ``ProgressHook`` imported from
    ``pyannote.audio.pipelines.utils.hook`` at the top of the file.
    """

    def __init__(self, status, progress_bar):
        self.status = status  # needs update(label=..., state=...), e.g. the object from st.status
        self.progress_bar = progress_bar  # needs progress(fraction), e.g. the object from st.progress
        self.total = 0
        self.completed = 0
        self.current_stage = ""

    def __call__(self, *args, **kwargs):
        if len(args) == 2 and isinstance(args[0], str):
            # (stage, data): announce the new stage.
            self.current_stage = args[0]
            self.status.update(label=f"Processing: {self.current_stage}", state="running")
        elif len(args) == 2 and all(isinstance(arg, (int, float)) for arg in args):
            # (completed, total) given positionally.
            self.completed, self.total = args
            self._update_progress()
            return

        # Keyword progress may accompany a stage name, so check it independently
        # of the positional forms instead of in an exclusive elif.
        completed = kwargs.get("completed")
        total = kwargs.get("total")
        if completed is not None and total is not None:
            self.completed = completed
            self.total = total
            self._update_progress()

    def _update_progress(self):
        """Render ``completed / total``, clamped to [0, 1], in the status label and bar."""
        if self.total > 0:
            progress_percentage = max(0.0, min(self.completed / self.total, 1.0))
            self.status.update(
                label=f"Processing: {self.current_stage} - {progress_percentage:.1%} complete",
                state="running",
            )
            self.progress_bar.progress(progress_percentage)



def preprocess_audio(tmp_path):
    """Convert an audio file to 16 kHz mono, zero-pad it, and save it as a temporary WAV.

    The waveform is padded with trailing zeros up to a whole number of
    10-second (160,000-sample) segments. Status messages are shown in the UI.

    Args:
        tmp_path: Path to the input audio file (any format pydub can decode).

    Returns:
        ``(waveform, sample_rate, processed_path)``: ``waveform`` is a float32
        tensor of shape ``(1, num_samples)`` with values in [-1, 1),
        ``sample_rate`` is always 16000, and ``processed_path`` is a new
        temporary WAV file that the caller must delete.
    """
    target_rate = 16000
    segment_size = 10 * target_rate  # 10 seconds at 16 kHz = 160000 samples

    audio = AudioSegment.from_file(tmp_path)

    # Down-mix any multi-channel input (not only stereo) to mono.
    if audio.channels > 1:
        audio = audio.set_channels(1)

    if audio.frame_rate != target_rate:
        audio = audio.set_frame_rate(target_rate)
        st.info("Resampled audio to 16 kHz")

    # Scale signed integer PCM by the full-scale value of the actual sample
    # width; a fixed 32768 divisor is only correct for 16-bit samples.
    full_scale = float(1 << (8 * audio.sample_width - 1))
    samples = np.asarray(audio.get_array_of_samples(), dtype=np.float32) / full_scale
    waveform = torch.from_numpy(samples).unsqueeze(0)

    # Zero-pad up to the next multiple of segment_size.
    padding_length = -waveform.shape[1] % segment_size
    if padding_length > 0:
        pad = torch.zeros((waveform.shape[0], padding_length))
        waveform = torch.cat((waveform, pad), dim=1)
        st.info(f"Padded waveform with {padding_length} zeros")
    else:
        st.info("No padding needed")

    # Only reserve a unique file name inside the `with`; write after the handle
    # is closed, because an open NamedTemporaryFile cannot be reopened by name
    # on some platforms (e.g. Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as processed_file:
        processed_path = processed_file.name
    torchaudio.save(processed_path, waveform, target_rate)
    st.info("Saved processed waveform to temporary WAV file")

    return waveform, target_rate, processed_path

def _version_tuple(version):
    """Parse the leading dotted release number of *version* into an int tuple.

    Local/pre-release suffixes are ignored and the tuple is padded with zeros
    to at least three components, e.g. ``"2.1.0+cu118"`` -> ``(2, 1, 0)`` and
    ``"3.1"`` -> ``(3, 1, 0)``. Returns None when *version* does not start with
    a number.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", str(version))
    if match is None:
        return None
    parts = [int(piece) for piece in match.group().split(".")]
    parts.extend([0] * (3 - len(parts)))
    return tuple(parts)


def check_versions():
    """Display the installed pyannote.audio and PyTorch versions and warn if either is old.

    Versions are compared numerically, because plain string comparison gets
    cases like "10.0.0" < "2.0.0" wrong. A version that cannot be parsed
    produces no warning.
    """
    st.info("Checking package versions...")

    pyannote_version = pyannote.audio.__version__
    torch_version = torch.__version__

    st.write(f"Pyannote Audio version: {pyannote_version}")
    st.write(f"PyTorch version: {torch_version}")

    pyannote_release = _version_tuple(pyannote_version)
    if pyannote_release is not None and pyannote_release < (3, 1, 0):
        st.warning("Your pyannote.audio version might be outdated. Consider upgrading to 3.1.0 or later.")

    torch_release = _version_tuple(torch_version)
    if torch_release is not None and torch_release < (2, 0, 0):
        st.warning("Your PyTorch version might be outdated. Consider upgrading to 2.0.0 or later.")


check_versions()

def verify_token(token):
    """Return True if *token* is accepted by the Hugging Face ``whoami`` endpoint.

    On success, shows the authenticated account name. On any failure, shows
    the error message and returns False instead of raising.
    """
    hub = HfApi()
    try:
        identity = hub.whoami(token=token)
        st.success(f"Token verified. Logged in as: {identity['name']}")
    except Exception as exc:
        st.error(f"Token verification failed: {str(exc)}")
        return False
    else:
        return True

def check_hf_api():
    """Query the Hub model-info endpoint for the diarization repo and report the result.

    On success, shows the JSON response in an expander. On a request error,
    shows the error and, if the server sent a response, its body, with an
    extra hint for HTTP 403. Request errors are reported, never raised.
    """
    st.info("Checking Hugging Face API...")
    api_url = "https://huggingface.co/api/models/pyannote/speaker-diarization-3.1"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    response = None
    try:
        # Bound the wait so a stalled connection cannot hang the script run.
        response = requests.get(api_url, headers=headers, timeout=30)
        response.raise_for_status()
        st.success("Successfully connected to Hugging Face API")
        with st.expander("API Response"):
            st.json(response.json())
    except requests.exceptions.RequestException as e:
        st.error(f"Error connecting to Hugging Face API: {str(e)}")
        # Connection failures and timeouts carry no response; fall back to the
        # local one (set if the GET itself succeeded) and skip details when neither exists.
        failed_response = e.response if e.response is not None else response
        if failed_response is None:
            return
        if failed_response.status_code == 403:
            st.error("Access denied. Please check your token permissions.")
            st.info("Ensure your token has permission to access gated repositories.")
        st.code(failed_response.text)

def verify_model_files():
    """Fetch each expected file of the diarization repo and report whether it exists locally.

    Each file is requested with ``hf_hub_download``. One success or error
    message is shown per file; download errors are reported, not raised.
    """
    st.info("Verifying model files...")
    # NOTE(review): confirm these files actually live in pyannote/speaker-diarization-3.1;
    # model weights may be hosted in other repos that its config.yaml references.
    required_files = [
        "config.yaml",
        "pytorch_model.bin",
        "pyannote_serialized_object.bin"
    ]

    for filename in required_files:
        try:
            # `token` is the current name of the deprecated `use_auth_token` argument.
            path = hf_hub_download("pyannote/speaker-diarization-3.1", filename=filename, token=HF_TOKEN)
        except Exception as e:
            st.error(f"Error downloading {filename}: {str(e)}")
            continue

        if os.path.exists(path):
            st.success(f"File {filename} found at {path}")
        else:
            st.error(f"File {filename} not found")


@st.cache_resource
def load_pipeline():
    """Load the pyannote speaker-diarization-3.1 pipeline, moving it to GPU when available.

    Wrapped in ``st.cache_resource``, so a successfully loaded pipeline is
    reused across reruns. Progress and failures are reported in the UI.

    Returns:
        The loaded pyannote ``Pipeline``.

    Raises:
        RuntimeError: if ``Pipeline.from_pretrained`` returns None.
        Exception: any error raised while loading or moving the pipeline; it
            is shown in the UI and then re-raised.
    """
    try:
        st.info("Attempting to load the pipeline...")
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=HF_TOKEN
        )
        # pyannote can signal a failed (e.g. gated/unauthorized) download by
        # returning None rather than raising. Fail here so None is neither
        # reported as success nor cached and called later.
        if pipeline is None:
            raise RuntimeError(
                "Pipeline.from_pretrained returned None for 'pyannote/speaker-diarization-3.1'. "
                "Make sure HF_TOKEN has access to this gated model."
            )
        st.success("Pipeline created successfully")

        if torch.cuda.is_available():
            st.info("Moving pipeline to GPU...")
            pipeline.to(torch.device("cuda"))
            st.success("Pipeline moved to GPU")

        return pipeline
    except Exception as e:
        st.error(f"Error loading pipeline: {str(e)}")
        st.error("Error details:")
        st.code(traceback.format_exc())
        raise

@st.cache_resource
def load_speechbrain_model():
    """Load the ``speechbrain/spkrec-ecapa-voxceleb`` encoder, cached via ``st.cache_resource``.

    Status messages are shown before and after loading.
    """
    st.info("Loading SpeechBrain model...")
    encoder = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
    )
    st.success("SpeechBrain model loaded successfully")
    return encoder

# Sidebar: optional speaker-count constraints for the diarization pipeline.
with st.sidebar:
    st.header("Settings")
    show_advanced = st.toggle("Show Advanced Options")
    if show_advanced:
        num_speakers = st.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
        min_speakers = st.number_input("Minimum number of speakers", min_value=1, value=1)
        max_speakers = st.number_input("Maximum number of speakers", min_value=1, value=5)
        # An inverted range cannot be satisfied; collapse it to the minimum
        # instead of handing min_speakers > max_speakers to the pipeline.
        if min_speakers > max_speakers:
            st.warning("Maximum number of speakers is lower than the minimum; using the minimum as the maximum.")
            max_speakers = min_speakers

# Main content
tab1, tab2, tab3 = st.tabs(["Upload & Process", "Results", "Visualization"])



with tab1:
    uploaded_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'flac'])

    if uploaded_file is not None:
        import hashlib

        audio_bytes = uploaded_file.getvalue()

        # Optional speaker-count constraints from the sidebar, passed to the pipeline.
        speaker_args = {}
        if show_advanced:
            if num_speakers > 0:
                speaker_args["num_speakers"] = num_speakers
            else:
                speaker_args["min_speakers"] = min_speakers
                speaker_args["max_speakers"] = max_speakers

        # Streamlit re-executes this whole script on every widget interaction
        # (e.g. the "Visualize Diarization" button), and the uploader keeps its
        # file across reruns. Reuse the previous result when neither the audio
        # nor the speaker settings changed, instead of re-running diarization.
        cache_key = (
            hashlib.sha256(audio_bytes).hexdigest(),
            tuple(sorted(speaker_args.items())),
        )
        cached = st.session_state.get("diarization_cache")

        if cached is not None and cached["key"] == cache_key:
            diarization = cached["diarization"]
            rttm_content = cached["rttm_content"]
            st.info("Using cached diarization results for this file.")
        else:
            # Save uploaded file temporarily
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
                tmp_file.write(audio_bytes)
                tmp_path = tmp_file.name
            processed_path = None  # set once preprocessing has written its WAV

            try:
                if verify_token(HF_TOKEN):
                    check_hf_api()
                    verify_model_files()
                    pipeline = load_pipeline()
                    speechbrain_model = load_speechbrain_model()
                else:
                    st.stop()

                # Preprocess the audio file
                waveform, sample_rate, processed_path = preprocess_audio(tmp_path)

                with st.status("Processing audio...", expanded=True) as status:
                    progress_bar = st.progress(0)
                    progress_hook = ProgressHook(status, progress_bar)

                    # Run the pipeline on the processed audio file
                    diarization_args = {
                        "file": processed_path,
                        "hook": progress_hook,
                        **speaker_args,
                    }
                    diarization = pipeline(**diarization_args)
                    status.update(label="Diarization complete!", state="complete")

                # RTTM fields are space-separated, so derive a whitespace-free
                # file id from the uploaded name, not the random temp-file name.
                file_id = "_".join(os.path.splitext(uploaded_file.name)[0].split()) or "audio"
                rttm_content = "".join(
                    f"SPEAKER {file_id} 1 {turn.start:.3f} {turn.duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
                    for turn, _, speaker in diarization.itertracks(yield_label=True)
                )

                # Cache before the optional embedding step, so a failure there
                # does not force the diarization to be recomputed on the next rerun.
                st.session_state["diarization_cache"] = {
                    "key": cache_key,
                    "diarization": diarization,
                    "rttm_content": rttm_content,
                }

                # Use SpeechBrain for speaker embedding (optional)
                embeddings = speechbrain_model.encode_batch(waveform)
                st.success("Speaker embeddings generated successfully")

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
                st.error("Error details:")
                st.code(traceback.format_exc())

            finally:
                # Clean up the temporary files
                os.unlink(tmp_path)
                if processed_path is not None:
                    os.unlink(processed_path)


with tab2:
    # Results exist only if diarization completed (or was served from cache) on this run.
    if 'diarization' in locals():
        st.subheader("Diarization Results")
        st.metric("Number of speakers detected", len(diarization.labels()))

        with st.expander("RTTM Output"):
            st.text_area("RTTM Content", rttm_content, height=300)

        st.download_button(
            label="Download RTTM file",
            data=rttm_content,
            file_name="diarization.rttm",
            mime="text/plain"
        )

with tab3:
    if 'diarization' in locals():
        if st.button("Visualize Diarization"):
            fig, ax = plt.subplots(figsize=(10, 2))
            # pyannote.core's `notebook` helper draws an Annotation with
            # `plot_annotation`; it has no `plot_diarization` method.
            notebook.plot_annotation(diarization, ax=ax)
            plt.tight_layout()
            st.pyplot(fig)
            # pyplot keeps every figure alive until it is closed; close it so
            # repeated reruns do not accumulate figures in memory.
            plt.close(fig)

# Debug Information: runtime environment details for diagnosing deployment issues.
with st.expander("Debug Information"):
    for debug_line in (
        f"Working directory: {os.getcwd()}",
        f"Files in working directory: {os.listdir()}",
        f"Python version: {sys.version.split()[0]}",
        f"PyTorch version: {torch.__version__}",
        f"Pyannote Audio version: {pyannote.audio.__version__}",
        f"CUDA available: {torch.cuda.is_available()}",
        f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}",
    ):
        st.write(debug_line)

# Token Permissions Instructions: static help text for gated-model access problems.
with st.expander("Token Permissions"):
    st.markdown("""
    If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
    1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
    2. Find your token or create a new one
    3. Ensure "Read" access is granted
    4. Check the box for "Access to gated repositories"
    5. Save the changes and try again
    """)

# Clear Cache Button: deletes the local ./model_cache directory if it exists.
if st.button("Clear Cache"):
    import shutil

    cache_dir = "./model_cache"
    if not os.path.exists(cache_dir):
        st.info("No cache directory found.")
    else:
        shutil.rmtree(cache_dir)
        st.success("Cache cleared successfully.")