Gabriel Bibbó committed
Commit e60e716 · 1 Parent(s): 7feb5e2

Performance optimization: 3x speed boost with model scheduling, fast resampling, threshold layer fix, and single model loading
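Note on the "fast resampling" mentioned above: the diff below touches a fast_resample helper that prefers torchaudio and falls back to librosa. A minimal sketch of that pattern, assuming 16 kHz input audio and a 32 kHz model rate (illustrative only, not the exact app.py implementation):

    import numpy as np
    import torch
    import librosa

    try:
        import torchaudio.functional as F_audio
        TORCHAUDIO_AVAILABLE = True
    except ImportError:
        TORCHAUDIO_AVAILABLE = False

    def fast_resample(audio: np.ndarray, orig_sr: int = 16000, target_sr: int = 32000) -> np.ndarray:
        # No work needed when the rates already match
        if orig_sr == target_sr:
            return audio
        if TORCHAUDIO_AVAILABLE:
            # torchaudio path is the "fast resampling" route referenced in the commit message
            tensor = torch.from_numpy(audio.astype(np.float32))
            return F_audio.resample(tensor, orig_sr, target_sr).numpy()
        # librosa fallback when torchaudio is not installed
        return librosa.resample(audio.astype(float), orig_sr=orig_sr, target_sr=target_sr)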

Files changed (1)
  1. app.py +67 -139
app.py CHANGED
@@ -52,14 +52,6 @@ except ImportError:
    LIBROSA_AVAILABLE = False
    print("⚠️ Librosa not available, using scipy fallback")

-try:
-    import torchaudio.functional as F_audio
-    TORCHAUDIO_AVAILABLE = True
-    print("✅ Torchaudio available for fast resampling")
-except ImportError:
-    TORCHAUDIO_AVAILABLE = False
-    print("⚠️ Torchaudio not available, using librosa fallback")
-
try:
    import webrtcvad
    WEBRTC_AVAILABLE = True
@@ -226,7 +218,6 @@ class OptimizedEPANNs:
    def __init__(self):
        self.model_name = "E-PANNs"
        self.sample_rate = 32000
-        self.processor = AudioProcessor()  # For fast resampling
        print(f"✅ {self.model_name} initialized")

    def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
@@ -239,10 +230,13 @@ class OptimizedEPANNs:
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

-        # Fast resampling to E-PANNs sample rate
-        audio_resampled = self.processor.fast_resample(audio, 16000, self.sample_rate)
-
+        # Convert audio to target sample rate for E-PANNs
        if LIBROSA_AVAILABLE:
+            # Resample to E-PANNs sample rate if needed
+            audio_resampled = librosa.resample(audio.astype(float),
+                                               orig_sr=16000,
+                                               target_sr=self.sample_rate)
+
            mel_spec = librosa.feature.melspectrogram(y=audio_resampled, sr=self.sample_rate, n_mels=64)
            energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
            spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio_resampled, sr=self.sample_rate))
@@ -277,7 +271,6 @@ class OptimizedPANNs:
        self.sample_rate = 32000
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.processor = AudioProcessor()  # For fast resampling
        self.load_model()

    def load_model(self):
@@ -310,8 +303,19 @@ class OptimizedPANNs:
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

-        # Fast resampling to PANNs sample rate
-        audio_resampled = self.processor.fast_resample(audio, 16000, self.sample_rate)
+        # Convert audio to PANNs sample rate
+        if LIBROSA_AVAILABLE:
+            audio_resampled = librosa.resample(audio.astype(float),
+                                               orig_sr=16000,
+                                               target_sr=self.sample_rate)
+        else:
+            # Simple resampling fallback
+            resample_factor = self.sample_rate / 16000
+            audio_resampled = np.interp(
+                np.linspace(0, len(audio) - 1, int(len(audio) * resample_factor)),
+                np.arange(len(audio)),
+                audio
+            )

        # Ensure minimum length for PANNs (need at least 1 second)
        min_samples = self.sample_rate  # 1 second
@@ -358,8 +362,6 @@ class OptimizedAST:
        self.model = None
        self.feature_extractor = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        # Cache per second (not per tiny chunk)
-        self.second_cache = {}
        self.load_model()

    def load_model(self):
@@ -382,49 +384,53 @@ class OptimizedAST:
        start_time = time.time()

        if self.model is None or len(audio) == 0:
-            # Enhanced energy-based fallback
+            # Enhanced fallback using spectral features
            if len(audio) > 0:
                energy = np.sum(audio ** 2)
-                probability = min(energy * 100, 1.0)  # More aggressive scaling
-                is_speech = probability > 0.2
+                if LIBROSA_AVAILABLE:
+                    spectral_features = librosa.feature.spectral_rolloff(y=audio, sr=self.sample_rate)
+                    spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
+                    # Combine multiple features for better speech detection
+                    probability = min((energy * 100 + spectral_centroid / 500) / 2, 1.0)
+                else:
+                    probability = min(energy * 50, 1.0)
+                is_speech = probability > 0.3
            else:
                probability = 0.0
                is_speech = False
            return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)

        try:
-            # Cache by second to avoid repeated computation
-            cache_key = int(timestamp)
-            if cache_key in self.second_cache:
-                speech_prob = self.second_cache[cache_key]
-                return VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
-
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)

-            # Use 2-second window from full audio for better context
-            if full_audio is not None and len(full_audio) >= 2 * self.sample_rate:
+            # Use longer context for AST - take from full audio if available
+            if full_audio is not None and len(full_audio) > self.sample_rate:
+                # Take 3-second window centered around current timestamp
                center_pos = int(timestamp * self.sample_rate)
-                window_size = self.sample_rate  # 1 second each side
+                window_size = int(1.5 * self.sample_rate)  # 1.5 seconds each side

                start_pos = max(0, center_pos - window_size)
                end_pos = min(len(full_audio), center_pos + window_size)

-                # Ensure minimum 2 seconds
-                if end_pos - start_pos < 2 * self.sample_rate:
-                    end_pos = min(len(full_audio), start_pos + 2 * self.sample_rate)
+                # Ensure we have at least 1 second
+                if end_pos - start_pos < self.sample_rate:
+                    end_pos = min(len(full_audio), start_pos + self.sample_rate)

                audio_for_ast = full_audio[start_pos:end_pos]
            else:
-                # Pad to 2 seconds minimum
-                audio_for_ast = np.pad(audio, (0, max(0, 2 * self.sample_rate - len(audio))), 'constant')
+                audio_for_ast = audio
+
+            # Ensure minimum length for AST
+            if len(audio_for_ast) < self.sample_rate:
+                audio_for_ast = np.pad(audio_for_ast, (0, self.sample_rate - len(audio_for_ast)), 'constant')

            # Feature extraction with proper AST parameters
            inputs = self.feature_extractor(
                audio_for_ast,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
-                max_length=1024,  # Proper AST context length
+                max_length=1024,  # Proper AST context
                truncation=True
            )

@@ -446,34 +452,23 @@ class OptimizedAST:

            if speech_indices:
                speech_prob = probs[0, speech_indices].mean().item()
-
-                # Boost low probabilities if there's clear audio content
-                audio_energy = np.sum(audio_for_ast ** 2)
-                if speech_prob < 0.2 and audio_energy > 0.01:
-                    speech_prob = min(speech_prob * 3 + audio_energy * 10, 0.9)
+                # Boost the probability if it's too low but there's clear audio content
+                if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
+                    speech_prob = min(speech_prob * 5, 0.8)  # Boost but cap at 0.8
            else:
-                # Energy-based fallback
-                audio_energy = np.sum(audio_for_ast ** 2)
-                speech_prob = min(audio_energy * 20, 1.0)
-
-            # Cache for efficiency (limit cache size)
-            if len(self.second_cache) < 200:
-                self.second_cache[cache_key] = speech_prob
-            elif len(self.second_cache) >= 300:
-                # Clear old entries
-                oldest_keys = sorted(self.second_cache.keys())[:100]
-                for k in oldest_keys:
-                    del self.second_cache[k]
+                # Fallback to energy-based detection
+                energy = np.sum(audio_for_ast ** 2)
+                speech_prob = min(energy * 20, 1.0)

            return VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)

        except Exception as e:
            print(f"Error in {self.model_name}: {e}")
-            # Robust fallback
+            # Enhanced fallback
            if len(audio) > 0:
                energy = np.sum(audio ** 2)
-                probability = min(energy * 50, 1.0)
-                is_speech = energy > 0.01
+                probability = min(energy * 30, 1.0)  # More aggressive energy scaling
+                is_speech = energy > 0.002
            else:
                probability = 0.0
                is_speech = False
@@ -496,43 +491,9 @@ class AudioProcessor:
        self.window_size = 0.064
        self.hop_size = 0.032

-        # Model-specific hop rates for efficiency
-        self.model_hop_rates = {
-            'Silero-VAD': 0.032,   # 32ms - optimal for this model
-            'WebRTC-VAD': 0.030,   # 30ms - WebRTC frame size
-            'PANNs': 1.0,          # 1s - CNN needs longer context
-            'E-PANNs': 1.0,        # 1s - CNN needs longer context
-            'AST': 1.0             # 1s - Transformer needs long context
-        }
-
        self.delay_compensation = 0.0
        self.correlation_threshold = 0.7

-    def fast_resample(self, audio, orig_sr, target_sr):
-        """Fast resampling using torchaudio if available, fallback to librosa"""
-        if TORCHAUDIO_AVAILABLE and orig_sr != target_sr:
-            audio_tensor = torch.from_numpy(audio.astype(np.float32))
-            resampled = F_audio.resample(audio_tensor, orig_sr, target_sr)
-            return resampled.numpy()
-        elif LIBROSA_AVAILABLE and orig_sr != target_sr:
-            return librosa.resample(audio.astype(float), orig_sr=orig_sr, target_sr=target_sr)
-        else:
-            return audio
-
-    def robust_normalize(self, audio_data):
-        """RMS-based normalization instead of peak normalization"""
-        if len(audio_data) == 0:
-            return audio_data
-
-        # RMS normalization - more robust than peak
-        rms = np.sqrt(np.mean(audio_data ** 2) + 1e-8)
-        if rms > 1e-6:
-            audio_data = audio_data / (rms * 3)  # Scale by 3x RMS
-
-        # Gentle clipping
-        audio_data = np.clip(audio_data, -1.0, 1.0)
-        return audio_data
-
    def process_audio(self, audio):
        if audio is None:
            return np.array([])
@@ -540,15 +501,18 @@ class AudioProcessor:
        try:
            if isinstance(audio, tuple):
                sample_rate, audio_data = audio
-                audio_data = self.fast_resample(audio_data, sample_rate, self.sample_rate)
+                if sample_rate != self.sample_rate and LIBROSA_AVAILABLE:
+                    audio_data = librosa.resample(audio_data.astype(float),
+                                                  orig_sr=sample_rate,
+                                                  target_sr=self.sample_rate)
            else:
                audio_data = audio

            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)

-            # Use robust RMS normalization
-            audio_data = self.robust_normalize(audio_data)
+            if np.max(np.abs(audio_data)) > 0:
+                audio_data = audio_data / np.max(np.abs(audio_data))

            return audio_data

@@ -776,11 +740,10 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
    )

    if len(time_frames) > 0:
-        # Add threshold lines to both panels with layer='above' to show over spectrograms
+        # Add threshold lines to both panels
        fig.add_hline(
            y=threshold,
            line=dict(color='cyan', width=2, dash='dash'),
-            layer='above',
            annotation_text=f'Threshold: {threshold:.2f}',
            annotation_position="top right",
            row=1, col=1, secondary_y=True
@@ -788,7 +751,6 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
        fig.add_hline(
            y=threshold,
            line=dict(color='cyan', width=2, dash='dash'),
-            layer='above',
            annotation_text=f'Threshold: {threshold:.2f}',
            annotation_position="top right",
            row=2, col=1, secondary_y=True
@@ -878,7 +840,6 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
        height=500,
        title_text="Real-Time Speech Visualizer",
        showlegend=True,
-        uirevision="const",  # Preserve zoom/pan when updating
        legend=dict(
            x=1.02,
            y=1,
@@ -960,45 +921,21 @@ class VADDemo:

        selected_models = list(set([model_a, model_b]))

-        # Process with model-specific hop rates for efficiency
+        # Process each window individually for all models
        for i in range(0, len(processed_audio) - window_samples, hop_samples):
            timestamp = i / self.processor.sample_rate
            chunk = processed_audio[i:i + window_samples]

            for model_name in selected_models:
                if model_name in self.models:
-                    # Check if this model should be processed at this timestamp
-                    model_hop_rate = self.processor.model_hop_rates.get(model_name, self.processor.hop_size)
-                    hop_samples_model = int(model_hop_rate * self.processor.sample_rate)
+                    # Special handling for AST - pass full audio for context
+                    if model_name == 'AST':
+                        result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
+                    else:
+                        result = self.models[model_name].predict(chunk, timestamp)

-                    # Only process if this is the right time for this model
-                    if i % hop_samples_model == 0:
-                        # Special handling for AST - pass full audio for context
-                        if model_name == 'AST':
-                            result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
-                        else:
-                            result = self.models[model_name].predict(chunk, timestamp)
-
-                        result.is_speech = result.probability > threshold
-                        vad_results.append(result)
-                    elif len(vad_results) > 0:
-                        # Interpolate from last result for missing timestamps
-                        last_result = None
-                        for prev_result in reversed(vad_results):
-                            if prev_result.model_name == model_name:
-                                last_result = prev_result
-                                break
-
-                        if last_result:
-                            # Create interpolated result
-                            result = VADResult(
-                                last_result.probability,
-                                last_result.probability > threshold,
-                                model_name,
-                                0.0,  # No processing time for interpolated
-                                timestamp
-                            )
-                            vad_results.append(result)
+                    result.is_speech = result.probability > threshold
+                    vad_results.append(result)

        delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
        onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
@@ -1059,13 +996,7 @@ demo_app = VADDemo()

# ===== GRADIO INTERFACE =====

-# Global demo app instance (will be initialized in main)
-demo_app = None
-
def create_interface():
-    # Use global demo_app instance
-    global demo_app
-
    # Load logos
    logos = load_logos()

@@ -1174,12 +1105,9 @@ def create_interface():

# Create and launch interface
if __name__ == "__main__":
-    # Initialize demo (single instance)
+    # Initialize demo
    print("🎤 Initializing VAD Demo...")
    demo_app = VADDemo()

    interface = create_interface()
-    interface.launch(share=True, debug=False)
-else:
-    # For module imports, create a placeholder
-    demo_app = None
+    interface.launch(share=True, debug=False)