Update app.py
Browse files
app.py
CHANGED
|
@@ -14,8 +14,8 @@ import pyannote.audio
|
|
| 14 |
import sys
|
| 15 |
import traceback
|
| 16 |
|
| 17 |
-
# Set page
|
| 18 |
-
st.set_page_config(page_title="Optimized Speaker Diarization App")
|
| 19 |
|
| 20 |
st.title("Optimized Speaker Diarization App")
|
| 21 |
|
|
@@ -62,7 +62,8 @@ def check_hf_api():
|
|
| 62 |
response = requests.get(api_url, headers=headers)
|
| 63 |
response.raise_for_status()
|
| 64 |
st.success("Successfully connected to Hugging Face API")
|
| 65 |
-
st.
|
|
|
|
| 66 |
except requests.exceptions.RequestException as e:
|
| 67 |
st.error(f"Error connecting to Hugging Face API: {str(e)}")
|
| 68 |
if response.status_code == 403:
|
|
@@ -129,67 +130,76 @@ def load_pipeline():
|
|
| 129 |
|
| 130 |
raise e
|
| 131 |
|
| 132 |
-
#
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
#
|
| 136 |
-
st.
|
| 137 |
-
num_speakers = st.sidebar.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
|
| 138 |
-
min_speakers = st.sidebar.number_input("Minimum number of speakers", min_value=1, value=1)
|
| 139 |
-
max_speakers = st.sidebar.number_input("Maximum number of speakers", min_value=1, value=5)
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
else:
|
| 150 |
-
st.stop()
|
| 151 |
-
|
| 152 |
-
if uploaded_file is not None:
|
| 153 |
-
# Save uploaded file temporarily
|
| 154 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
| 155 |
-
tmp_file.write(uploaded_file.getvalue())
|
| 156 |
-
tmp_path = tmp_file.name
|
| 157 |
-
|
| 158 |
-
try:
|
| 159 |
-
# Set up progress hook
|
| 160 |
-
progress_text = st.empty()
|
| 161 |
-
progress_bar = st.progress(0)
|
| 162 |
-
|
| 163 |
-
def progress_hook(step: int, total: int, stage: str):
|
| 164 |
-
progress_text.text(f"Processing: {stage}")
|
| 165 |
-
progress_bar.progress(step / total)
|
| 166 |
-
|
| 167 |
-
# Run the pipeline on the audio file
|
| 168 |
-
with st.spinner('Processing audio...'):
|
| 169 |
-
diarization_args = {
|
| 170 |
-
"file": tmp_path,
|
| 171 |
-
"min_speakers": min_speakers,
|
| 172 |
-
"max_speakers": max_speakers,
|
| 173 |
-
"hook": ProgressHook(progress_hook)
|
| 174 |
-
}
|
| 175 |
-
if num_speakers > 0:
|
| 176 |
-
diarization_args["num_speakers"] = num_speakers
|
| 177 |
-
|
| 178 |
-
diarization = pipeline(**diarization_args)
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
rttm_content += rttm_line
|
| 187 |
|
| 188 |
-
#
|
| 189 |
-
|
| 190 |
-
st.text_area("RTTM Output", rttm_content, height=300)
|
| 191 |
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
st.download_button(
|
| 194 |
label="Download RTTM file",
|
| 195 |
data=rttm_content,
|
|
@@ -197,58 +207,34 @@ if uploaded_file is not None:
|
|
| 197 |
mime="text/plain"
|
| 198 |
)
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
st.write(f"Number of speakers detected: {len(diarization.labels())}")
|
| 203 |
-
|
| 204 |
-
# Visualize diarization
|
| 205 |
if st.button("Visualize Diarization"):
|
| 206 |
fig, ax = plt.subplots(figsize=(10, 2))
|
| 207 |
notebook.plot_diarization(diarization, ax=ax)
|
| 208 |
plt.tight_layout()
|
| 209 |
st.pyplot(fig)
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
else
|
| 220 |
-
st.info("Please upload an audio file to start.")
|
| 221 |
-
|
| 222 |
-
# Display usage instructions
|
| 223 |
-
st.sidebar.markdown("""
|
| 224 |
-
## Usage Instructions
|
| 225 |
-
1. Upload an audio file (WAV, MP3, or FLAC).
|
| 226 |
-
2. Adjust advanced options if needed.
|
| 227 |
-
3. Wait for the diarization process to complete.
|
| 228 |
-
4. View and download the RTTM results.
|
| 229 |
-
5. Optionally, visualize the diarization.
|
| 230 |
-
""")
|
| 231 |
-
|
| 232 |
-
# Display system information
|
| 233 |
-
st.sidebar.markdown(f"""
|
| 234 |
-
## System Information
|
| 235 |
-
- Python version: {sys.version.split()[0]}
|
| 236 |
-
- PyTorch version: {torch.__version__}
|
| 237 |
-
- Pyannote Audio version: {pyannote.audio.__version__}
|
| 238 |
-
- CUDA available: {torch.cuda.is_available()}
|
| 239 |
-
- Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
|
| 240 |
-
""")
|
| 241 |
|
| 242 |
# Token Permissions Instructions
|
| 243 |
-
st.
|
| 244 |
-
|
| 245 |
-
If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
|
| 246 |
-
1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
|
| 247 |
-
2. Find your token or create a new one
|
| 248 |
-
3. Ensure "Read" access is granted
|
| 249 |
-
4. Check the box for "Access to gated repositories"
|
| 250 |
-
5. Save the changes and try again
|
| 251 |
-
""")
|
| 252 |
|
| 253 |
# Clear Cache Button
|
| 254 |
if st.button("Clear Cache"):
|
|
@@ -258,9 +244,4 @@ if st.button("Clear Cache"):
|
|
| 258 |
shutil.rmtree(cache_dir)
|
| 259 |
st.success("Cache cleared successfully.")
|
| 260 |
else:
|
| 261 |
-
st.info("No cache directory found.")
|
| 262 |
-
|
| 263 |
-
# Debug Information
|
| 264 |
-
st.subheader("Debug Information")
|
| 265 |
-
st.write(f"Working directory: {os.getcwd()}")
|
| 266 |
-
st.write(f"Files in working directory: {os.listdir()}")
|
|
|
|
| 14 |
import sys
|
| 15 |
import traceback
|
| 16 |
|
| 17 |
+
# Page setup: browser-tab title and wide layout, followed by the on-page heading.
APP_TITLE = "Optimized Speaker Diarization App"
st.set_page_config(page_title=APP_TITLE, layout="wide")

st.title(APP_TITLE)
|
| 21 |
|
|
|
|
| 62 |
response = requests.get(api_url, headers=headers)
|
| 63 |
response.raise_for_status()
|
| 64 |
st.success("Successfully connected to Hugging Face API")
|
| 65 |
+
with st.expander("API Response"):
|
| 66 |
+
st.json(response.json())
|
| 67 |
except requests.exceptions.RequestException as e:
|
| 68 |
st.error(f"Error connecting to Hugging Face API: {str(e)}")
|
| 69 |
if response.status_code == 403:
|
|
|
|
| 130 |
|
| 131 |
raise e
|
| 132 |
|
| 133 |
+
# Sidebar
# Bind the speaker settings up front so they always exist, even while the
# advanced options are hidden. The values mirror the widget defaults below
# (0 means "detect the number of speakers automatically").
num_speakers = 0
min_speakers = 1
max_speakers = 5

with st.sidebar:
    st.header("Settings")
    show_advanced = st.toggle("Show Advanced Options")
    if show_advanced:
        # The widgets only render (and override the defaults) when requested.
        num_speakers = st.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
        min_speakers = st.number_input("Minimum number of speakers", min_value=1, value=1)
        max_speakers = st.number_input("Maximum number of speakers", min_value=1, value=5)
|
| 141 |
|
| 142 |
+
# Main content: one tab per stage of the workflow.
TAB_LABELS = ["Upload & Process", "Results", "Visualization"]
tab1, tab2, tab3 = st.tabs(TAB_LABELS)
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
with tab1:
    uploaded_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'flac'])

    if uploaded_file is not None:
        # The pipeline is given a file path, so persist the upload to a
        # temporary file that keeps the original extension.
        upload_suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=upload_suffix) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        try:
            # Validate credentials and model availability before loading the pipeline.
            if verify_token(HF_TOKEN):
                check_hf_api()
                verify_model_files()
                pipeline = load_pipeline()
            else:
                st.stop()

            with st.status("Processing audio...", expanded=True) as status:
                # Create the progress bar once and update it in place; calling
                # st.progress() inside the hook would add a new bar per callback.
                progress_bar = st.progress(0.0)

                def progress_hook(step: int, total: int, stage: str):
                    """Show the current stage and the completed fraction."""
                    status.update(label=f"Processing: {stage}", state="running")
                    # Guard against a zero total and keep the value in [0, 1],
                    # the range st.progress accepts for floats.
                    fraction = step / total if total else 0.0
                    progress_bar.progress(min(max(fraction, 0.0), 1.0))

                # Run the pipeline on the audio file
                # NOTE(review): ProgressHook is defined/imported outside this view;
                # confirm that it accepts a callback as its first argument.
                diarization_args = {
                    "file": tmp_path,
                    "min_speakers": min_speakers if show_advanced else 1,
                    "max_speakers": max_speakers if show_advanced else 5,
                    "hook": ProgressHook(progress_hook),
                }
                if show_advanced and num_speakers > 0:
                    diarization_args["num_speakers"] = num_speakers

                diarization = pipeline(**diarization_args)
                status.update(label="Diarization complete!", state="complete")

            # Generate RTTM content: one SPEAKER line per speech turn.
            rttm_uri = os.path.basename(tmp_path)
            rttm_content = "".join(
                f"SPEAKER {rttm_uri} 1 {turn.start:.3f} {turn.duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            )

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Error details:")
            st.code(traceback.format_exc())

        finally:
            # Clean up the temporary file. This lives in `finally` because
            # st.stop() halts the script with an exception that the
            # `except Exception` clause above does not catch.
            os.unlink(tmp_path)
|
|
|
|
| 194 |
|
| 195 |
+
with tab2:
|
| 196 |
+
if 'diarization' in locals():
|
| 197 |
+
st.subheader("Diarization Results")
|
| 198 |
+
st.metric("Number of speakers detected", len(diarization.labels()))
|
| 199 |
+
|
| 200 |
+
with st.expander("RTTM Output"):
|
| 201 |
+
st.text_area("RTTM Content", rttm_content, height=300)
|
| 202 |
+
|
| 203 |
st.download_button(
|
| 204 |
label="Download RTTM file",
|
| 205 |
data=rttm_content,
|
|
|
|
| 207 |
mime="text/plain"
|
| 208 |
)
|
| 209 |
|
| 210 |
+
with tab3:
    # Only offer the plot once a diarization result exists in this script run.
    if 'diarization' in locals():
        if st.button("Visualize Diarization"):
            fig, ax = plt.subplots(figsize=(10, 2))
            notebook.plot_diarization(diarization, ax=ax)
            fig.tight_layout()
            st.pyplot(fig)
            # pyplot keeps every figure it creates alive until it is closed;
            # release it so repeated clicks do not accumulate figures in memory.
            plt.close(fig)
|
| 217 |
|
| 218 |
+
# Debug Information: runtime environment details, collapsed by default.
with st.expander("Debug Information"):
    cuda_available = torch.cuda.is_available()
    debug_lines = (
        f"Working directory: {os.getcwd()}",
        f"Files in working directory: {os.listdir()}",
        f"Python version: {sys.version.split()[0]}",
        f"PyTorch version: {torch.__version__}",
        f"Pyannote Audio version: {pyannote.audio.__version__}",
        f"CUDA available: {cuda_available}",
        f"Device: {'CUDA' if cuda_available else 'CPU'}",
    )
    for debug_line in debug_lines:
        st.write(debug_line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
# Token Permissions Instructions: help text for Hugging Face access errors.
TOKEN_PERMISSIONS_HELP = """
If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
2. Find your token or create a new one
3. Ensure "Read" access is granted
4. Check the box for "Access to gated repositories"
5. Save the changes and try again
"""

with st.expander("Token Permissions"):
    st.markdown(TOKEN_PERMISSIONS_HELP)
|
| 238 |
|
| 239 |
# Clear Cache Button
|
| 240 |
if st.button("Clear Cache"):
|
|
|
|
| 244 |
shutil.rmtree(cache_dir)
|
| 245 |
st.success("Cache cleared successfully.")
|
| 246 |
else:
|
| 247 |
+
st.info("No cache directory found.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|