Gabriel Bibbó committed on
Commit
5bbaead
·
1 Parent(s): bcae560

Hotfix: Restore basic functionality - fix AST saturation and PANNs execution

Browse files
Files changed (1) hide show
  1. app.py +79 -37
app.py CHANGED
@@ -362,6 +362,8 @@ class OptimizedAST:
362
  self.model = None
363
  self.feature_extractor = None
364
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
365
  self.load_model()
366
 
367
  def load_model(self):
@@ -401,29 +403,52 @@ class OptimizedAST:
401
  return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
402
 
403
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  if len(audio.shape) > 1:
405
  audio = audio.mean(axis=1)
406
 
407
- # Use longer context for AST - take from full audio if available
408
- if full_audio is not None and len(full_audio) > self.sample_rate:
409
- # Take 3-second window centered around current timestamp
410
  center_pos = int(timestamp * self.sample_rate)
411
- window_size = int(1.5 * self.sample_rate) # 1.5 seconds each side
412
 
413
  start_pos = max(0, center_pos - window_size)
414
  end_pos = min(len(full_audio), center_pos + window_size)
415
 
416
- # Ensure we have at least 1 second
417
- if end_pos - start_pos < self.sample_rate:
418
- end_pos = min(len(full_audio), start_pos + self.sample_rate)
 
 
419
 
420
  audio_for_ast = full_audio[start_pos:end_pos]
421
  else:
422
  audio_for_ast = audio
423
 
424
- # Ensure minimum length for AST
425
- if len(audio_for_ast) < self.sample_rate:
426
- audio_for_ast = np.pad(audio_for_ast, (0, self.sample_rate - len(audio_for_ast)), 'constant')
 
 
 
 
 
 
427
 
428
  # Feature extraction with proper AST parameters
429
  inputs = self.feature_extractor(
@@ -452,23 +477,33 @@ class OptimizedAST:
452
 
453
  if speech_indices:
454
  speech_prob = probs[0, speech_indices].mean().item()
455
- # Boost the probability if it's too low but there's clear audio content
456
  if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
457
- speech_prob = min(speech_prob * 5, 0.8) # Boost but cap at 0.8
458
  else:
459
- # Fallback to energy-based detection
460
- energy = np.sum(audio_for_ast ** 2)
461
- speech_prob = min(energy * 20, 1.0)
462
 
463
- return VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
 
 
 
 
 
 
 
 
 
 
464
 
465
  except Exception as e:
466
  print(f"Error in {self.model_name}: {e}")
467
  # Enhanced fallback
468
  if len(audio) > 0:
469
- energy = np.sum(audio ** 2)
470
- probability = min(energy * 30, 1.0) # More aggressive energy scaling
471
- is_speech = energy > 0.002
472
  else:
473
  probability = 0.0
474
  is_speech = False
@@ -491,6 +526,15 @@ class AudioProcessor:
491
  self.window_size = 0.064
492
  self.hop_size = 0.032
493
 
 
 
 
 
 
 
 
 
 
494
  self.delay_compensation = 0.0
495
  self.correlation_threshold = 0.7
496
 
@@ -921,21 +965,24 @@ class VADDemo:
921
 
922
  selected_models = list(set([model_a, model_b]))
923
 
924
- # Process each window individually for all models
925
  for i in range(0, len(processed_audio) - window_samples, hop_samples):
926
  timestamp = i / self.processor.sample_rate
927
  chunk = processed_audio[i:i + window_samples]
928
 
929
  for model_name in selected_models:
930
  if model_name in self.models:
931
- # Special handling for AST - pass full audio for context
932
- if model_name == 'AST':
933
- result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
934
- else:
935
- result = self.models[model_name].predict(chunk, timestamp)
936
-
937
- result.is_speech = result.probability > threshold
938
- vad_results.append(result)
 
 
 
939
 
940
  delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
941
  onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
@@ -990,10 +1037,6 @@ class VADDemo:
990
  traceback.print_exc()
991
  return None, f"❌ Error: {str(e)}", f"Error details: {traceback.format_exc()}"
992
 
993
- # Initialize demo
994
- print("🎤 Initializing VAD Demo...")
995
- demo_app = VADDemo()
996
-
997
  # ===== GRADIO INTERFACE =====
998
 
999
  def create_interface():
@@ -1053,7 +1096,7 @@ def create_interface():
1053
 
1054
  model_b = gr.Dropdown(
1055
  choices=["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"],
1056
- value="PANNs",
1057
  label="Model B (Bottom Panel)"
1058
  )
1059
 
@@ -1103,11 +1146,10 @@ def create_interface():
1103
 
1104
  return interface
1105
 
 
 
 
1106
  # Create and launch interface
1107
  if __name__ == "__main__":
1108
- # Initialize demo
1109
- print("🎤 Initializing VAD Demo...")
1110
- demo_app = VADDemo()
1111
-
1112
  interface = create_interface()
1113
  interface.launch(share=True, debug=False)
 
362
  self.model = None
363
  self.feature_extractor = None
364
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
365
+ self.prediction_cache = {} # Cache to avoid recomputation
366
+ self.cache_window = 1.0 # Cache results per one-second window
367
  self.load_model()
368
 
369
  def load_model(self):
 
403
  return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
404
 
405
  try:
406
+ # Cache key based on timestamp rounded to cache window
407
+ cache_key = int(timestamp / self.cache_window)
408
+
409
+ # Check cache first
410
+ if cache_key in self.prediction_cache:
411
+ cached_result = self.prediction_cache[cache_key]
412
+ # Return cached result with updated timestamp
413
+ return VADResult(
414
+ cached_result.probability,
415
+ cached_result.is_speech,
416
+ cached_result.model_name + " (cached)",
417
+ time.time() - start_time,
418
+ timestamp
419
+ )
420
+
421
  if len(audio.shape) > 1:
422
  audio = audio.mean(axis=1)
423
 
424
+ # Use longer context for AST - preferably 2 seconds
425
+ if full_audio is not None and len(full_audio) >= 2 * self.sample_rate:
426
+ # Take 2-second window centered around current timestamp
427
  center_pos = int(timestamp * self.sample_rate)
428
+ window_size = self.sample_rate # 1 second each side
429
 
430
  start_pos = max(0, center_pos - window_size)
431
  end_pos = min(len(full_audio), center_pos + window_size)
432
 
433
+ # Ensure we have at least 2 seconds
434
+ if end_pos - start_pos < 2 * self.sample_rate:
435
+ end_pos = min(len(full_audio), start_pos + 2 * self.sample_rate)
436
+ if end_pos - start_pos < 2 * self.sample_rate:
437
+ start_pos = max(0, end_pos - 2 * self.sample_rate)
438
 
439
  audio_for_ast = full_audio[start_pos:end_pos]
440
  else:
441
  audio_for_ast = audio
442
 
443
+ # Ensure minimum length for AST (2 seconds preferred, minimum 1 second)
444
+ min_samples = 2 * self.sample_rate # 2 seconds
445
+ if len(audio_for_ast) < min_samples:
446
+ audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), 'constant')
447
+
448
+ # Truncate if too long (AST can handle up to ~10s, but we'll use 3s max for efficiency)
449
+ max_samples = 3 * self.sample_rate
450
+ if len(audio_for_ast) > max_samples:
451
+ audio_for_ast = audio_for_ast[:max_samples]
452
 
453
  # Feature extraction with proper AST parameters
454
  inputs = self.feature_extractor(
 
477
 
478
  if speech_indices:
479
  speech_prob = probs[0, speech_indices].mean().item()
480
+ # Apply more reasonable thresholding for AST
481
  if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
482
+ speech_prob = min(speech_prob * 3, 0.7) # Moderate boost, cap at 0.7
483
  else:
484
+ # Fallback to energy-based detection with higher threshold
485
+ energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast) # Normalize by length
486
+ speech_prob = min(energy * 50, 1.0)
487
 
488
+ result = VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
489
+
490
+ # Cache the result
491
+ self.prediction_cache[cache_key] = result
492
+
493
+ # Clean old cache entries (keep only last 10 seconds)
494
+ cache_keys_to_remove = [k for k in self.prediction_cache.keys() if k < cache_key - 10]
495
+ for k in cache_keys_to_remove:
496
+ del self.prediction_cache[k]
497
+
498
+ return result
499
 
500
  except Exception as e:
501
  print(f"Error in {self.model_name}: {e}")
502
  # Enhanced fallback
503
  if len(audio) > 0:
504
+ energy = np.sum(audio ** 2) / len(audio) # Normalize by length
505
+ probability = min(energy * 100, 1.0) # More conservative scaling
506
+ is_speech = energy > 0.001 # Lower threshold for fallback
507
  else:
508
  probability = 0.0
509
  is_speech = False
 
526
  self.window_size = 0.064
527
  self.hop_size = 0.032
528
 
529
+ # Model-specific hop sizes for efficiency
530
+ self.model_hop_sizes = {
531
+ "Silero-VAD": 0.032,
532
+ "WebRTC-VAD": 0.03,
533
+ "E-PANNs": 1.0,
534
+ "PANNs": 1.0,
535
+ "AST": 1.0 # Process AST only once per second
536
+ }
537
+
538
  self.delay_compensation = 0.0
539
  self.correlation_threshold = 0.7
540
 
 
965
 
966
  selected_models = list(set([model_a, model_b]))
967
 
968
+ # Process each window with model-specific hop sizes for efficiency
969
  for i in range(0, len(processed_audio) - window_samples, hop_samples):
970
  timestamp = i / self.processor.sample_rate
971
  chunk = processed_audio[i:i + window_samples]
972
 
973
  for model_name in selected_models:
974
  if model_name in self.models:
975
+ # Check if we should process this model at this timestamp
976
+ model_hop = self.processor.model_hop_sizes.get(model_name, self.processor.hop_size)
977
+ if i % int(model_hop * self.processor.sample_rate) == 0:
978
+ # Special handling for AST - pass full audio for context
979
+ if model_name == 'AST':
980
+ result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
981
+ else:
982
+ result = self.models[model_name].predict(chunk, timestamp)
983
+
984
+ result.is_speech = result.probability > threshold
985
+ vad_results.append(result)
986
 
987
  delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
988
  onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
 
1037
  traceback.print_exc()
1038
  return None, f"❌ Error: {str(e)}", f"Error details: {traceback.format_exc()}"
1039
 
 
 
 
 
1040
  # ===== GRADIO INTERFACE =====
1041
 
1042
  def create_interface():
 
1096
 
1097
  model_b = gr.Dropdown(
1098
  choices=["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"],
1099
+ value="AST",
1100
  label="Model B (Bottom Panel)"
1101
  )
1102
 
 
1146
 
1147
  return interface
1148
 
1149
+ # Initialize demo only once
1150
+ demo_app = VADDemo()
1151
+
1152
  # Create and launch interface
1153
  if __name__ == "__main__":
 
 
 
 
1154
  interface = create_interface()
1155
  interface.launch(share=True, debug=False)