"""Streamlit app for speaker diarization with pyannote.audio and SpeechBrain."""

import os
import sys
import tempfile
import traceback

import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
import torch
import torchaudio

import pyannote.audio
from huggingface_hub import HfApi, hf_hub_download
from packaging.version import Version
from pyannote.audio import Pipeline
from pyannote.core import notebook
from pydub import AudioSegment
from speechbrain.pretrained import EncoderClassifier
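# Runtime assumptions, based on the checks this app performs below:
# - HF_TOKEN must be set as a Space secret; the token needs "Read" access and
#   permission to access gated repositories (pyannote/speaker-diarization-3.1 is gated).
# - pydub requires ffmpeg on the system path to decode mp3/flac uploads.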

st.set_page_config(page_title="Optimized Speaker Diarization App", layout="wide")

st.title("Optimized Speaker Diarization App")

HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    st.error("HF_TOKEN not found in environment variables. Please set it in your Hugging Face Space secrets.")
    st.stop()


class ProgressHook:
    """Streamlit-aware progress hook for the pyannote pipeline.

    pyannote invokes its hook as hook(step_name, step_artifact, total=..., completed=...),
    so a single call can carry both the stage name and the progress counters.
    """

    def __init__(self, status, progress_bar):
        self.status = status
        self.progress_bar = progress_bar
        self.total = 0
        self.completed = 0
        self.current_stage = ""

    def __call__(self, *args, **kwargs):
        # The first positional argument, when present, names the current stage.
        if args and isinstance(args[0], str):
            self.current_stage = args[0]
            self.status.update(label=f"Processing: {self.current_stage}", state="running")

        # Progress counters may arrive as keyword arguments in the same call.
        if kwargs.get("completed") is not None and kwargs.get("total") is not None:
            self.completed = kwargs["completed"]
            self.total = kwargs["total"]
            self._update_progress()
        elif len(args) == 2 and all(isinstance(arg, (int, float)) for arg in args):
            self.completed, self.total = args
            self._update_progress()

    def _update_progress(self):
        if self.total > 0:
            progress_percentage = min(self.completed / self.total, 1.0)
            self.status.update(
                label=f"Processing: {self.current_stage} - {progress_percentage:.1%} complete",
                state="running",
            )
            self.progress_bar.progress(progress_percentage)


def preprocess_audio(tmp_path):
    # Load with pydub so any ffmpeg-supported format can be handled.
    audio = AudioSegment.from_file(tmp_path)

    # Downmix stereo to mono.
    if audio.channels == 2:
        audio = audio.set_channels(1)

    # Resample to the 16 kHz rate expected by the diarization pipeline.
    if audio.frame_rate != 16000:
        audio = audio.set_frame_rate(16000)
        st.info("Resampled audio to 16 kHz")

    # Normalize to 16-bit samples so the /32768 scaling below is correct.
    if audio.sample_width != 2:
        audio = audio.set_sample_width(2)

    samples = np.array(audio.get_array_of_samples())

    # Scale 16-bit integers to floats in [-1, 1] and add a channel dimension.
    waveform = torch.FloatTensor(samples).unsqueeze(0) / 32768.0

    # Zero-pad the waveform to a whole number of 10-second segments
    # (160,000 samples at 16 kHz).
    segment_size = 160000
    num_segments = (waveform.shape[1] + segment_size - 1) // segment_size
    expected_length = num_segments * segment_size
    padding_length = expected_length - waveform.shape[1]

    if padding_length > 0:
        pad = torch.zeros((waveform.shape[0], padding_length))
        waveform = torch.cat((waveform, pad), dim=1)
        st.info(f"Padded waveform with {padding_length} zeros")
    else:
        st.info("No padding needed")

    # Write the processed waveform to a temporary WAV file for the pipeline.
    # Save after the handle is closed so the write also works on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as processed_file:
        processed_path = processed_file.name
    torchaudio.save(processed_path, waveform, 16000)
    st.info("Saved processed waveform to temporary WAV file")

    return waveform, 16000, processed_path


def check_versions():
    st.info("Checking package versions...")

    pyannote_version = pyannote.audio.__version__
    torch_version = torch.__version__

    st.write(f"Pyannote Audio version: {pyannote_version}")
    st.write(f"PyTorch version: {torch_version}")

    # Compare parsed versions rather than raw strings, so that e.g. "3.10" sorts after "3.9".
    if Version(pyannote_version) < Version("3.1.0"):
        st.warning("Your pyannote.audio version might be outdated. Consider upgrading to 3.1.0 or later.")

    if Version(torch_version) < Version("2.0.0"):
        st.warning("Your PyTorch version might be outdated. Consider upgrading to 2.0.0 or later.")


check_versions()


def verify_token(token):
    api = HfApi()
    try:
        user_info = api.whoami(token=token)
        st.success(f"Token verified. Logged in as: {user_info['name']}")
        return True
    except Exception as e:
        st.error(f"Token verification failed: {str(e)}")
        return False


def check_hf_api():
    st.info("Checking Hugging Face API...")
    api_url = "https://huggingface.co/api/models/pyannote/speaker-diarization-3.1"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    try:
        response = requests.get(api_url, headers=headers, timeout=30)
        response.raise_for_status()
        st.success("Successfully connected to Hugging Face API")
        with st.expander("API Response"):
            st.json(response.json())
    except requests.exceptions.RequestException as e:
        st.error(f"Error connecting to Hugging Face API: {str(e)}")
        # e.response is None for connection-level failures, so guard before using it.
        if e.response is not None and e.response.status_code == 403:
            st.error("Access denied. Please check your token permissions.")
            st.info("Ensure your token has permission to access gated repositories.")
            st.code(e.response.text)


def verify_model_files():
    st.info("Verifying model files...")
    # The speaker-diarization-3.1 repository ships only a pipeline config; the model
    # weights live in the gated repositories that the config references.
    required_files = [
        "config.yaml",
    ]

    for file in required_files:
        try:
            path = hf_hub_download("pyannote/speaker-diarization-3.1", filename=file, token=HF_TOKEN)
            if os.path.exists(path):
                st.success(f"File {file} found at {path}")
            else:
                st.error(f"File {file} not found")
        except Exception as e:
            st.error(f"Error downloading {file}: {str(e)}")


@st.cache_resource
def load_pipeline():
    try:
        st.info("Attempting to load the pipeline...")
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=HF_TOKEN,
        )
        st.success("Pipeline created successfully")

        if torch.cuda.is_available():
            st.info("Moving pipeline to GPU...")
            pipeline.to(torch.device("cuda"))
            st.success("Pipeline moved to GPU")

        return pipeline
    except Exception as e:
        st.error(f"Error loading pipeline: {str(e)}")
        st.error("Error details:")
        st.code(traceback.format_exc())
        raise


@st.cache_resource
def load_speechbrain_model():
    st.info("Loading SpeechBrain model...")
    # Downloaded once and cached across reruns by st.cache_resource.
    classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")
    st.success("SpeechBrain model loaded successfully")
    return classifier


with st.sidebar:
    st.header("Settings")
    show_advanced = st.toggle("Show Advanced Options")
    if show_advanced:
        num_speakers = st.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
        min_speakers = st.number_input("Minimum number of speakers", min_value=1, value=1)
        max_speakers = st.number_input("Maximum number of speakers", min_value=1, value=5)


tab1, tab2, tab3 = st.tabs(["Upload & Process", "Results", "Visualization"])

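# Tab 1 performs upload and diarization. Streamlit reruns this script top to bottom,
# so tabs 2 and 3 simply read the `diarization` result produced earlier in the same run.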
with tab1:
    uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "flac"])

    if uploaded_file is not None:
        # Persist the upload to disk so pydub/ffmpeg can read it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        try:
            if verify_token(HF_TOKEN):
                check_hf_api()
                verify_model_files()
                pipeline = load_pipeline()
                speechbrain_model = load_speechbrain_model()
            else:
                st.stop()

            waveform, sample_rate, processed_path = preprocess_audio(tmp_path)

            with st.status("Processing audio...", expanded=True) as status:
                progress_bar = st.progress(0)
                progress_hook = ProgressHook(status, progress_bar)

                diarization_args = {
                    "file": processed_path,
                    "hook": progress_hook,
                }
                if show_advanced:
                    if num_speakers > 0:
                        diarization_args["num_speakers"] = num_speakers
                    else:
                        diarization_args["min_speakers"] = min_speakers
                        diarization_args["max_speakers"] = max_speakers

                diarization = pipeline(**diarization_args)
                status.update(label="Diarization complete!", state="complete")

            # Serialize the result in RTTM format (the file id is conventionally
            # the basename without its extension).
            file_id = os.path.splitext(os.path.basename(tmp_path))[0]
            rttm_content = ""
            for turn, _, speaker in diarization.itertracks(yield_label=True):
                rttm_content += f"SPEAKER {file_id} 1 {turn.start:.3f} {turn.duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"

            # Speaker embeddings for the whole file (not used further downstream yet).
            embeddings = speechbrain_model.encode_batch(waveform)
            st.success("Speaker embeddings generated successfully")

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Error details:")
            st.code(traceback.format_exc())

        finally:
            # Clean up temporary files.
            os.unlink(tmp_path)
            if "processed_path" in locals():
                os.unlink(processed_path)


with tab2:
    if "diarization" in locals():
        st.subheader("Diarization Results")
        st.metric("Number of speakers detected", len(diarization.labels()))

        with st.expander("RTTM Output"):
            st.text_area("RTTM Content", rttm_content, height=300)

        st.download_button(
            label="Download RTTM file",
            data=rttm_content,
            file_name="diarization.rttm",
            mime="text/plain",
        )


with tab3:
    if "diarization" in locals():
        if st.button("Visualize Diarization"):
            fig, ax = plt.subplots(figsize=(10, 2))
            # pyannote.core's notebook helper plots Annotation objects via plot_annotation.
            notebook.plot_annotation(diarization, ax=ax)
            plt.tight_layout()
            st.pyplot(fig)

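# Diagnostic helpers: environment details, token-permission guidance, and a cache reset.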

with st.expander("Debug Information"):
    st.write(f"Working directory: {os.getcwd()}")
    st.write(f"Files in working directory: {os.listdir()}")
    st.write(f"Python version: {sys.version.split()[0]}")
    st.write(f"PyTorch version: {torch.__version__}")
    st.write(f"Pyannote Audio version: {pyannote.audio.__version__}")
    st.write(f"CUDA available: {torch.cuda.is_available()}")
    st.write(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")


with st.expander("Token Permissions"):
    st.markdown("""
    If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
    1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
    2. Find your token or create a new one
    3. Ensure "Read" access is granted
    4. Check the box for "Access to gated repositories"
    5. Save the changes and try again
    """)


if st.button("Clear Cache"):
    # Clear Streamlit's cached resources (the diarization pipeline and the
    # SpeechBrain model) so they are reloaded on the next run.
    st.cache_resource.clear()
    st.success("Cache cleared successfully.")