ricklon commited on
Commit
1bc4ac3
·
verified ·
1 Parent(s): 2b446d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -108
app.py CHANGED
@@ -14,8 +14,8 @@ import pyannote.audio
14
  import sys
15
  import traceback
16
 
17
- # Set page title
18
- st.set_page_config(page_title="Optimized Speaker Diarization App")
19
 
20
  st.title("Optimized Speaker Diarization App")
21
 
@@ -62,7 +62,8 @@ def check_hf_api():
62
  response = requests.get(api_url, headers=headers)
63
  response.raise_for_status()
64
  st.success("Successfully connected to Hugging Face API")
65
- st.json(response.json())
 
66
  except requests.exceptions.RequestException as e:
67
  st.error(f"Error connecting to Hugging Face API: {str(e)}")
68
  if response.status_code == 403:
@@ -129,67 +130,76 @@ def load_pipeline():
129
 
130
  raise e
131
 
132
- # File uploader
133
- uploaded_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'flac'])
 
 
 
 
 
 
134
 
135
- # Advanced options
136
- st.sidebar.header("Advanced Options")
137
- num_speakers = st.sidebar.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
138
- min_speakers = st.sidebar.number_input("Minimum number of speakers", min_value=1, value=1)
139
- max_speakers = st.sidebar.number_input("Maximum number of speakers", min_value=1, value=5)
140
 
141
- if verify_token(HF_TOKEN):
142
- check_hf_api()
143
- verify_model_files()
144
- try:
145
- pipeline = load_pipeline()
146
- except Exception as e:
147
- st.error("Failed to load pipeline. Please check the error messages above.")
148
- st.stop()
149
- else:
150
- st.stop()
151
-
152
- if uploaded_file is not None:
153
- # Save uploaded file temporarily
154
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
155
- tmp_file.write(uploaded_file.getvalue())
156
- tmp_path = tmp_file.name
157
-
158
- try:
159
- # Set up progress hook
160
- progress_text = st.empty()
161
- progress_bar = st.progress(0)
162
-
163
- def progress_hook(step: int, total: int, stage: str):
164
- progress_text.text(f"Processing: {stage}")
165
- progress_bar.progress(step / total)
166
-
167
- # Run the pipeline on the audio file
168
- with st.spinner('Processing audio...'):
169
- diarization_args = {
170
- "file": tmp_path,
171
- "min_speakers": min_speakers,
172
- "max_speakers": max_speakers,
173
- "hook": ProgressHook(progress_hook)
174
- }
175
- if num_speakers > 0:
176
- diarization_args["num_speakers"] = num_speakers
177
-
178
- diarization = pipeline(**diarization_args)
179
 
180
- # Rest of the code remains the same...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- # Generate RTTM content
183
- rttm_content = ""
184
- for turn, _, speaker in diarization.itertracks(yield_label=True):
185
- rttm_line = f"SPEAKER {os.path.basename(tmp_path)} 1 {turn.start:.3f} {turn.duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
186
- rttm_content += rttm_line
187
 
188
- # Display RTTM content
189
- st.subheader("Diarization Results (RTTM format)")
190
- st.text_area("RTTM Output", rttm_content, height=300)
191
 
192
- # Provide download button for RTTM file
 
 
 
 
 
 
 
193
  st.download_button(
194
  label="Download RTTM file",
195
  data=rttm_content,
@@ -197,58 +207,34 @@ if uploaded_file is not None:
197
  mime="text/plain"
198
  )
199
 
200
- # Display additional information
201
- st.subheader("Diarization Information")
202
- st.write(f"Number of speakers detected: {len(diarization.labels())}")
203
-
204
- # Visualize diarization
205
  if st.button("Visualize Diarization"):
206
  fig, ax = plt.subplots(figsize=(10, 2))
207
  notebook.plot_diarization(diarization, ax=ax)
208
  plt.tight_layout()
209
  st.pyplot(fig)
210
 
211
- except Exception as e:
212
- st.error(f"An error occurred: {str(e)}")
213
- st.error("Error details:")
214
- st.code(traceback.format_exc())
215
-
216
- # Clean up the temporary file
217
- os.unlink(tmp_path)
218
-
219
- else:
220
- st.info("Please upload an audio file to start.")
221
-
222
- # Display usage instructions
223
- st.sidebar.markdown("""
224
- ## Usage Instructions
225
- 1. Upload an audio file (WAV, MP3, or FLAC).
226
- 2. Adjust advanced options if needed.
227
- 3. Wait for the diarization process to complete.
228
- 4. View and download the RTTM results.
229
- 5. Optionally, visualize the diarization.
230
- """)
231
-
232
- # Display system information
233
- st.sidebar.markdown(f"""
234
- ## System Information
235
- - Python version: {sys.version.split()[0]}
236
- - PyTorch version: {torch.__version__}
237
- - Pyannote Audio version: {pyannote.audio.__version__}
238
- - CUDA available: {torch.cuda.is_available()}
239
- - Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
240
- """)
241
 
242
  # Token Permissions Instructions
243
- st.markdown("""
244
- ## Token Permissions
245
- If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
246
- 1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
247
- 2. Find your token or create a new one
248
- 3. Ensure "Read" access is granted
249
- 4. Check the box for "Access to gated repositories"
250
- 5. Save the changes and try again
251
- """)
252
 
253
  # Clear Cache Button
254
  if st.button("Clear Cache"):
@@ -258,9 +244,4 @@ if st.button("Clear Cache"):
258
  shutil.rmtree(cache_dir)
259
  st.success("Cache cleared successfully.")
260
  else:
261
- st.info("No cache directory found.")
262
-
263
- # Debug Information
264
- st.subheader("Debug Information")
265
- st.write(f"Working directory: {os.getcwd()}")
266
- st.write(f"Files in working directory: {os.listdir()}")
 
14
  import sys
15
  import traceback
16
 
17
+ # Set page configuration
18
+ st.set_page_config(page_title="Optimized Speaker Diarization App", layout="wide")
19
 
20
  st.title("Optimized Speaker Diarization App")
21
 
 
62
  response = requests.get(api_url, headers=headers)
63
  response.raise_for_status()
64
  st.success("Successfully connected to Hugging Face API")
65
+ with st.expander("API Response"):
66
+ st.json(response.json())
67
  except requests.exceptions.RequestException as e:
68
  st.error(f"Error connecting to Hugging Face API: {str(e)}")
69
  if response.status_code == 403:
 
130
 
131
  raise e
132
 
133
+ # Sidebar
134
+ with st.sidebar:
135
+ st.header("Settings")
136
+ show_advanced = st.toggle("Show Advanced Options")
137
+ if show_advanced:
138
+ num_speakers = st.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
139
+ min_speakers = st.number_input("Minimum number of speakers", min_value=1, value=1)
140
+ max_speakers = st.number_input("Maximum number of speakers", min_value=1, value=5)
141
 
142
+ # Main content
143
+ tab1, tab2, tab3 = st.tabs(["Upload & Process", "Results", "Visualization"])
 
 
 
144
 
145
+ with tab1:
146
+ uploaded_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'flac'])
147
+
148
+ if uploaded_file is not None:
149
+ # Save uploaded file temporarily
150
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
151
+ tmp_file.write(uploaded_file.getvalue())
152
+ tmp_path = tmp_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ try:
155
+ if verify_token(HF_TOKEN):
156
+ check_hf_api()
157
+ verify_model_files()
158
+ pipeline = load_pipeline()
159
+ else:
160
+ st.stop()
161
+
162
+ with st.status("Processing audio...", expanded=True) as status:
163
+ # Set up progress hook
164
+ def progress_hook(step: int, total: int, stage: str):
165
+ status.update(label=f"Processing: {stage}", state="running")
166
+ st.progress(step / total)
167
+
168
+ # Run the pipeline on the audio file
169
+ diarization_args = {
170
+ "file": tmp_path,
171
+ "min_speakers": min_speakers if show_advanced else 1,
172
+ "max_speakers": max_speakers if show_advanced else 5,
173
+ "hook": ProgressHook(progress_hook)
174
+ }
175
+ if show_advanced and num_speakers > 0:
176
+ diarization_args["num_speakers"] = num_speakers
177
+
178
+ diarization = pipeline(**diarization_args)
179
+ status.update(label="Diarization complete!", state="complete")
180
+
181
+ # Generate RTTM content
182
+ rttm_content = ""
183
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
184
+ rttm_line = f"SPEAKER {os.path.basename(tmp_path)} 1 {turn.start:.3f} {turn.duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
185
+ rttm_content += rttm_line
186
 
187
+ except Exception as e:
188
+ st.error(f"An error occurred: {str(e)}")
189
+ st.error("Error details:")
190
+ st.code(traceback.format_exc())
 
191
 
192
+ # Clean up the temporary file
193
+ os.unlink(tmp_path)
 
194
 
195
+ with tab2:
196
+ if 'diarization' in locals():
197
+ st.subheader("Diarization Results")
198
+ st.metric("Number of speakers detected", len(diarization.labels()))
199
+
200
+ with st.expander("RTTM Output"):
201
+ st.text_area("RTTM Content", rttm_content, height=300)
202
+
203
  st.download_button(
204
  label="Download RTTM file",
205
  data=rttm_content,
 
207
  mime="text/plain"
208
  )
209
 
210
+ with tab3:
211
+ if 'diarization' in locals():
 
 
 
212
  if st.button("Visualize Diarization"):
213
  fig, ax = plt.subplots(figsize=(10, 2))
214
  notebook.plot_diarization(diarization, ax=ax)
215
  plt.tight_layout()
216
  st.pyplot(fig)
217
 
218
+ # Debug Information
219
+ with st.expander("Debug Information"):
220
+ st.write(f"Working directory: {os.getcwd()}")
221
+ st.write(f"Files in working directory: {os.listdir()}")
222
+ st.write(f"Python version: {sys.version.split()[0]}")
223
+ st.write(f"PyTorch version: {torch.__version__}")
224
+ st.write(f"Pyannote Audio version: {pyannote.audio.__version__}")
225
+ st.write(f"CUDA available: {torch.cuda.is_available()}")
226
+ st.write(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  # Token Permissions Instructions
229
+ with st.expander("Token Permissions"):
230
+ st.markdown("""
231
+ If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
232
+ 1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
233
+ 2. Find your token or create a new one
234
+ 3. Ensure "Read" access is granted
235
+ 4. Check the box for "Access to gated repositories"
236
+ 5. Save the changes and try again
237
+ """)
238
 
239
  # Clear Cache Button
240
  if st.button("Clear Cache"):
 
244
  shutil.rmtree(cache_dir)
245
  st.success("Cache cleared successfully.")
246
  else:
247
+ st.info("No cache directory found.")