Gabriel Bibbó committed
Commit aee7b20 · 1 Parent(s): d02d086

adjust app.py

Files changed (1)
  1. app.py +95 -293
app.py CHANGED
@@ -260,6 +260,7 @@ class OptimizedEPANNs:
     def __init__(self):
         self.model_name = "E-PANNs"
         self.sample_rate = 32000
         print(f"✅ {self.model_name} initialized")

         # Try to load PANNs AudioTagging as backend for E-PANNs
@@ -281,92 +282,50 @@ class OptimizedEPANNs:
         if len(audio.shape) > 1:
             audio = audio.mean(axis=1)

-        # For E-PANNs, we need to extract the appropriate window based on timestamp
-        window_duration = 6.0 # 6 seconds window for E-PANNs
-        window_samples = int(window_duration * 16000) # at 16kHz input rate

-        # Calculate the center position for this timestamp
-        center_sample = int(timestamp * 16000)
-        half_window = window_samples // 2

-        # Extract window centered at timestamp
-        start_idx = max(0, center_sample - half_window)
-        end_idx = min(len(audio), start_idx + window_samples)
-
-        # Adjust start if we're at the end of audio
-        if end_idx == len(audio) and end_idx - start_idx < window_samples:
-            start_idx = max(0, end_idx - window_samples)
-
-        audio_window = audio[start_idx:end_idx]
-
-        # Convert audio to target sample rate for E-PANNs (32kHz)
-        if LIBROSA_AVAILABLE:
-            # Resample to E-PANNs sample rate
-            audio_resampled = librosa.resample(audio_window.astype(float),
-                                               orig_sr=16000,
-                                               target_sr=self.sample_rate)

-            # For short audio, repeat it instead of padding with zeros
-            min_samples = 6 * self.sample_rate # 6 seconds
-            if len(audio_resampled) < min_samples:
-                # Repeat the audio to fill the minimum required length
-                num_repeats = int(np.ceil(min_samples / len(audio_resampled)))
-                audio_resampled = np.tile(audio_resampled, num_repeats)[:min_samples]

-            # If we have PANNs AT model, use it
-            if self.at_model is not None:
-                # Run inference
-                clipwise_output, _ = self.at_model.inference(audio_resampled[np.newaxis, :])
-
-                # Get speech-related classes
-                speech_keywords = [
-                    'speech', 'voice', 'talk', 'conversation', 'speaking',
-                    'male speech', 'female speech', 'child speech',
-                    'narration', 'monologue'
-                ]
-
-                speech_indices = []
-                for i, lbl in enumerate(labels):
-                    if any(word in lbl.lower() for word in speech_keywords):
-                        speech_indices.append(i)
-
-                if speech_indices:
-                    speech_probs = clipwise_output[0, speech_indices]
-                    speech_score = float(np.max(speech_probs))
-                else:
-                    speech_score = float(np.max(clipwise_output[0]))
             else:
-                # Fallback to spectral features
-                # Compute features
-                mel_spec = librosa.feature.melspectrogram(y=audio_resampled, sr=self.sample_rate, n_mels=64)
                 energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))

-                # Use actual non-repeated audio for some features
-                actual_audio_len = min(len(audio_resampled), int(len(audio_window) * self.sample_rate / 16000))
-                actual_audio = audio_resampled[:actual_audio_len]
-
-                spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=actual_audio, sr=self.sample_rate))
-                mfcc = librosa.feature.mfcc(y=actual_audio, sr=self.sample_rate, n_mfcc=13)
-                mfcc_var = np.var(mfcc, axis=1).mean()
-                zcr = np.mean(librosa.feature.zero_crossing_rate(actual_audio))
-
-                # Adjusted scaling for better speech detection
                 energy_score = np.clip((energy + 80) / 40, 0, 1)
                 centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
-                mfcc_score = np.clip(mfcc_var / 100, 0, 1)
-                zcr_score = np.clip(zcr * 10, 0, 1)
-
-                # Weighted combination
-                speech_score = (energy_score * 0.4 +
-                                centroid_score * 0.2 +
-                                mfcc_score * 0.3 +
-                                zcr_score * 0.1)
-        else:
-            from scipy import signal
-            # Basic fallback without librosa
-            f, t, Sxx = signal.spectrogram(audio_window, 16000)
-            energy = np.mean(10 * np.log10(Sxx + 1e-10))
-            speech_score = np.clip((energy + 100) / 50, 0, 1)

         probability = np.clip(speech_score, 0, 1)
         is_speech = probability > 0.4
@@ -432,54 +391,25 @@ class OptimizedPANNs:
         if len(audio.shape) > 1:
             audio = audio.mean(axis=1)

-        # For PANNs, extract the appropriate window based on timestamp
-        window_duration = 10.0 # 10 seconds window for PANNs
-        window_samples = int(window_duration * 16000) # at 16kHz input rate
-
-        # Calculate the center position for this timestamp
-        center_sample = int(timestamp * 16000)
-        half_window = window_samples // 2
-
-        # Extract window centered at timestamp
-        start_idx = max(0, center_sample - half_window)
-        end_idx = min(len(audio), start_idx + window_samples)
-
-        # Adjust start if we're at the end of audio
-        if end_idx == len(audio) and end_idx - start_idx < window_samples:
-            start_idx = max(0, end_idx - window_samples)
-
-        audio_window = audio[start_idx:end_idx]
-
         # Convert audio to PANNs sample rate
         if LIBROSA_AVAILABLE:
-            audio_resampled = librosa.resample(audio_window.astype(float),
                                                orig_sr=16000,
                                                target_sr=self.sample_rate)
         else:
             # Simple resampling fallback
             resample_factor = self.sample_rate / 16000
             audio_resampled = np.interp(
-                np.linspace(0, len(audio_window) - 1, int(len(audio_window) * resample_factor)),
-                np.arange(len(audio_window)),
-                audio_window
             )

-        # For short audio, use intelligent padding strategy
-        min_samples = 10 * self.sample_rate # 10 seconds for optimal performance
         if len(audio_resampled) < min_samples:
-            # Strategy: repeat the audio cyclically to maintain characteristics
-            num_repeats = int(np.ceil(min_samples / len(audio_resampled)))
-            audio_repeated = np.tile(audio_resampled, num_repeats)[:min_samples]
-
-            # Apply fade in/out to reduce artifacts
-            fade_len = int(0.1 * self.sample_rate) # 100ms fade
-            fade_in = np.linspace(0, 1, fade_len)
-            fade_out = np.linspace(1, 0, fade_len)
-
-            audio_repeated[:fade_len] *= fade_in
-            audio_repeated[-fade_len:] *= fade_out
-
-            audio_resampled = audio_repeated

         # Use SED for framewise predictions if available
         if self.sed_model is not None:
@@ -492,13 +422,8 @@ class OptimizedPANNs:
             if framewise_output.ndim == 3:
                 framewise_output = framewise_output[0] # Remove batch dimension

-            # Get frame corresponding to timestamp
-            audio_duration = len(audio_resampled) / self.sample_rate
-            if audio_duration > 0:
-                frame_idx = int((timestamp % audio_duration) / audio_duration * framewise_output.shape[0])
-                frame_idx = min(frame_idx, framewise_output.shape[0] - 1)
-            else:
-                frame_idx = 0

             # Get speech-related classes
             speech_keywords = [
@@ -551,11 +476,6 @@ class OptimizedPANNs:
                 noise_prob = np.mean(clip_probs[0, noise_indices])
                 # Adjust speech probability based on noise
                 speech_prob = speech_prob * (1 - noise_prob * 0.5)
-
-                # If using repeated audio, scale confidence based on original length
-                if len(audio_window) < 16000 * 2: # Less than 2 seconds
-                    confidence_scale = len(audio_window) / (16000 * 2)
-                    speech_prob = speech_prob * (0.5 + 0.5 * confidence_scale)

             else:
                 # Fallback if no speech indices found
@@ -579,15 +499,14 @@ class OptimizedPANNs:
             return VADResult(probability, is_speech, f"{self.model_name} (error)", time.time() - start_time, timestamp)

 class OptimizedAST:
-    """CORRECTED AST with proper 16kHz sample rate and sliding windows"""
     def __init__(self):
         self.model_name = "AST"
         self.sample_rate = 16000 # AST REQUIRES 16kHz
         self.model = None
         self.feature_extractor = None
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.prediction_cache = {} # Cache to avoid recomputation
-        self.cache_window = 1.0 # Cache results at one-second granularity
         self.load_model()

     def load_model(self):
@@ -616,12 +535,7 @@ class OptimizedAST:
     def predict(self, audio: np.ndarray, timestamp: float = 0.0, full_audio: np.ndarray = None) -> VADResult:
         start_time = time.time()

-        print(f"🔍 AST predict: audio_len={len(audio)}, timestamp={timestamp:.2f}, model_available={self.model is not None}")
-        if full_audio is not None:
-            print(f"🔍 AST: full_audio_len={len(full_audio)}")
-
         if self.model is None or len(audio) == 0:
-            print(f"❌ AST: Model unavailable or empty audio")
             # Enhanced fallback using spectral features
             if len(audio) > 0:
                 energy = np.sum(audio ** 2)
@@ -630,10 +544,8 @@ class OptimizedAST:
                     spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
                     # Combine multiple features for better speech detection
                     probability = min((energy * 100 + spectral_centroid / 1000) / 2, 1.0)
-                    print(f"🔄 AST fallback: energy={energy:.6f}, centroid={spectral_centroid:.1f}, prob={probability:.4f}")
                 else:
                     probability = min(energy * 50, 1.0)
-                    print(f"🔄 AST fallback (simple): energy={energy:.6f}, prob={probability:.4f}")
                 is_speech = probability > 0.25 # Use AST threshold
             else:
                 probability = 0.0
@@ -641,91 +553,39 @@ class OptimizedAST:
             return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)

         try:
-            # Cache key based on timestamp rounded to cache window
-            cache_key = int(timestamp / self.cache_window)
-
-            # Check cache first
-            if cache_key in self.prediction_cache:
-                cached_result = self.prediction_cache[cache_key]
-                print(f"✅ AST: Using cached result for t={timestamp:.2f}s")
-                # Return cached result with updated timestamp
-                return VADResult(
-                    cached_result.probability,
-                    cached_result.is_speech,
-                    cached_result.model_name + " (cached)",
-                    time.time() - start_time,
-                    timestamp
-                )

             if len(audio.shape) > 1:
                 audio = audio.mean(axis=1)
-                print(f"🔄 AST: Converted to mono")
-
-            # CRITICAL FIX: AST uses 16kHz, but input is already at 16kHz
-            # So we DON'T need to resample, just ensure it's float32
-            audio = audio.astype(np.float32)
-
-            # Use sliding window approach for temporal resolution
-            window_duration = 1.0 # 1 second windows
-            window_samples = int(window_duration * self.sample_rate)
-
-            # Get window for this timestamp
-            center_sample = int(timestamp * self.sample_rate)
-            half_window = window_samples // 2

-            start_idx = max(0, center_sample - half_window)
-            end_idx = min(len(audio), start_idx + window_samples)

-            # Adjust if at the end
-            if end_idx == len(audio) and end_idx - start_idx < window_samples:
-                start_idx = max(0, end_idx - window_samples)
-
-            audio_for_ast = audio[start_idx:end_idx]
-            print(f"🔄 AST: Extracted window [{start_idx}:{end_idx}], len={len(audio_for_ast)}")
-
-            # For short audio, use intelligent strategy
             min_samples = int(1.0 * self.sample_rate) # 1 second minimum
             if len(audio_for_ast) < min_samples:
-                print(f"⚠️ AST: Audio too short ({len(audio_for_ast)} samples), padding")
-                # Pad with zeros
-                audio_padded = np.zeros(min_samples)
-                audio_padded[:len(audio_for_ast)] = audio_for_ast
-                audio_for_ast = audio_padded
-                print(f"✅ AST: Padded to {len(audio_for_ast)} samples")
-
-            # Truncate if too long (AST can handle up to ~10s, but we use 1s windows)
-            max_samples = int(1.5 * self.sample_rate)
-            if len(audio_for_ast) > max_samples:
-                audio_for_ast = audio_for_ast[:max_samples]
-                print(f"✂️ AST: Truncated to {len(audio_for_ast)} samples")
-
-            print(f"🔄 AST: Feature extraction...")
-            # Feature extraction with proper AST parameters
             inputs = self.feature_extractor(
                 audio_for_ast,
                 sampling_rate=self.sample_rate, # Must be 16kHz
                 return_tensors="pt",
-                max_length=1024, # Proper AST context
-                padding="max_length", # Ensure consistent length
-                truncation=True
             )

-            print(f"✅ AST: Features extracted, input_shape={[v.shape if hasattr(v, 'shape') else type(v) for v in inputs.values()]}")
-
             # Move inputs to correct device and dtype
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
             if self.device.type == 'cuda' and hasattr(self.model, 'half'):
                 # Convert inputs to FP16 if model is in FP16
                 inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

-            print(f"🚀 AST: Running inference...")
             with torch.no_grad():
                 outputs = self.model(**inputs)
                 logits = outputs.logits
                 probs = torch.sigmoid(logits)

-            print(f"✅ AST: Inference complete, logits_shape={logits.shape}, probs_shape={probs.shape}")
-
             # Find speech-related classes with enhanced keywords
             label2id = self.model.config.label2id
             speech_indices = []
@@ -739,8 +599,6 @@ class OptimizedAST:
                 if any(word in lbl.lower() for word in speech_keywords):
                     speech_indices.append(idx)

-            print(f"🔍 AST: Found {len(speech_indices)} speech-related classes")
-
             # Also identify background/noise classes for better discrimination
             noise_keywords = ['silence', 'white noise', 'background']
             noise_indices = []
@@ -758,35 +616,16 @@ class OptimizedAST:
                     noise_prob = torch.mean(probs[0, noise_indices]).item()
                     # Reduce speech probability if high noise/silence detected
                     speech_prob = speech_prob * (1 - noise_prob * 0.3)
-
-                print(f"📈 AST: raw_speech_prob={speech_prob:.4f}")
-
-                # Adjust confidence for short audio
-                if len(audio) < self.sample_rate * 2: # Less than 2 seconds
-                    confidence_factor = len(audio) / (self.sample_rate * 2)
-                    speech_prob = speech_prob * (0.6 + 0.4 * confidence_factor)
-                    print(f"🔧 AST: Adjusted for short audio, final_prob={speech_prob:.4f}")

             else:
                 # Fallback to energy-based detection with better calibration
                 energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast) # Normalize by length
                 speech_prob = min(energy * 20, 1.0) # Better scaling
-                print(f"⚠️ AST: No speech classes found, using energy fallback: energy={energy:.6f}, prob={speech_prob:.4f}")

             # Use lower threshold specifically for AST (0.25 instead of 0.4)
             is_speech_ast = speech_prob > 0.25
             result = VADResult(float(speech_prob), is_speech_ast, self.model_name, time.time()-start_time, timestamp)

-            print(f"✅ AST: final_prob={speech_prob:.4f}, is_speech={is_speech_ast}")
-
-            # Cache the result
-            self.prediction_cache[cache_key] = result
-
-            # Clean old cache entries (keep only last 30 seconds for longer sessions)
-            cache_keys_to_remove = [k for k in self.prediction_cache.keys() if k < cache_key - 30]
-            for k in cache_keys_to_remove:
-                del self.prediction_cache[k]
-
             return result

         except Exception as e:
@@ -798,7 +637,6 @@ class OptimizedAST:
                 energy = np.sum(audio ** 2) / len(audio) # Normalize by length
                 probability = min(energy * 100, 1.0) # More conservative scaling
                 is_speech = energy > 0.001 # Lower threshold for fallback
-                print(f"🔄 AST error fallback: energy={energy:.6f}, prob={probability:.4f}")
             else:
                 probability = 0.0
                 is_speech = False
@@ -825,18 +663,18 @@ class AudioProcessor:
         self.model_windows = {
             "Silero-VAD": 0.032, # 32ms exactly as required (512 samples)
             "WebRTC-VAD": 0.03, # 30ms frames (480 samples)
-            "E-PANNs": 6.0, # 6 seconds minimum for reliable results
-            "PANNs": 10.0, # 10 seconds for optimal performance
-            "AST": 1.0 # Changed to 1 second for better temporal resolution
         }

-        # Model-specific hop sizes for efficiency
         self.model_hop_sizes = {
             "Silero-VAD": 0.016, # 16ms hop for Silero (512 samples window)
             "WebRTC-VAD": 0.03, # 30ms hop for WebRTC (match frame duration)
-            "E-PANNs": 0.1, # 100ms hop for 10 predictions/second
-            "PANNs": 0.1, # 100ms hop for 10 predictions/second
-            "AST": 0.1 # 100ms hop for 10 predictions/second
         }

         # Model-specific thresholds for better detection
@@ -1346,78 +1184,42 @@ class VADDemo:

             model_results = []

-            # Always use sliding window approach for consistent temporal resolution
-            if len(processed_audio) < window_samples:
-                debug_info.append(f" ⚠️ Audio shorter than window ({len(processed_audio)} < {window_samples}), using sliding window with padding")

-                # For short audio, still use sliding window but with the actual audio length
-                # This ensures we get the desired temporal resolution (10 predictions/second)
-                window_count = 0
-                audio_duration = len(processed_audio) / self.processor.sample_rate

-                # Calculate number of windows based on hop size
-                num_windows = max(1, int((audio_duration - window_size) / hop_size) + 1) if audio_duration > window_size else max(1, int(audio_duration / hop_size))

-                for i in range(0, len(processed_audio), hop_samples):
-                    timestamp = i / self.processor.sample_rate
-
-                    # For models that need long context, we'll use the full audio padded/repeated as needed
-                    # but report the timestamp based on the sliding window position
-                    if window_count < 3: # Log first 3 windows
-                        debug_info.append(f" 🔄 Window {window_count}: t={timestamp:.2f}s")
-
-                    # Special handling for different models
-                    if model_name == 'AST':
-                        result = self.models[model_name].predict(processed_audio, timestamp, full_audio=processed_audio)
-                    else:
-                        result = self.models[model_name].predict(processed_audio, timestamp)
-
-                    if window_count < 3: # Log first 3 results
-                        debug_info.append(f" 📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")
-
-                    # Use model-specific threshold
-                    result.is_speech = result.probability > model_threshold
-                    vad_results.append(result)
-                    model_results.append(result)
-                    window_count += 1
-
-                    # Stop if we've gone past the audio length
-                    if timestamp >= audio_duration:
-                        break

-                debug_info.append(f" 🎯 Total windows processed: {window_count}")
-            else:
-                # Audio is long enough - process in sliding windows
-                debug_info.append(f" ✅ Audio long enough, processing in windows")

-                window_count = 0
-                for i in range(0, len(processed_audio) - window_samples + 1, hop_samples):
-                    timestamp = i / self.processor.sample_rate
-
-                    # Extract window
-                    start_pos = i
-                    end_pos = min(len(processed_audio), i + window_samples)
-                    chunk = processed_audio[start_pos:end_pos]
-
-                    if window_count < 3: # Log first 3 windows
-                        debug_info.append(f" 🔄 Window {window_count}: t={timestamp:.2f}s, size={len(chunk)}")
-
-                    # Special handling for different models
-                    if model_name == 'AST':
-                        result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
-                    else:
-                        result = self.models[model_name].predict(chunk, timestamp)
-
-                    if window_count < 3: # Log first 3 results
-                        debug_info.append(f" 📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")
-
-                    # Use model-specific threshold
-                    result.is_speech = result.probability > model_threshold
-                    vad_results.append(result)
-                    model_results.append(result)
-                    window_count += 1

-                debug_info.append(f" 🎯 Total windows processed: {window_count}")

             # Summary for this model
             if model_results:
@@ -1594,7 +1396,7 @@ def create_interface():
     ---
     **Models**: Silero-VAD, WebRTC-VAD, E-PANNs, PANNs, AST | **Research**: WASPAA 2025 | **Institution**: University of Surrey, CVSSP

-    **Note**: Large models (PANNs: 10s, E-PANNs: 6s, AST: 6.4s) work best with longer recordings. Short clips will be processed intelligently.
     """)

     return interface
 
     def __init__(self):
         self.model_name = "E-PANNs"
         self.sample_rate = 32000
+        self.win_s = 1.0 # CHANGED from 6.0 to 1.0 for better temporal resolution
         print(f"✅ {self.model_name} initialized")

         # Try to load PANNs AudioTagging as backend for E-PANNs
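The AudioTagging backend itself is loaded outside this hunk. For reference, a minimal sketch of how the panns_inference package is typically initialised; the device string and checkpoint handling here are assumptions, not taken from app.py:

    # Sketch only — not part of this commit
    import numpy as np
    from panns_inference import AudioTagging, labels  # `labels` is the AudioSet class list used below

    at_model = AudioTagging(checkpoint_path=None, device='cpu')  # downloads the default Cnn14 checkpoint
    # inference() expects a (batch, samples) float array at 32 kHz and returns
    # (clipwise_output, embedding); clipwise_output holds one probability per AudioSet class:
    # clipwise_output, _ = at_model.inference(audio_32k[np.newaxis, :])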
 
         if len(audio.shape) > 1:
             audio = audio.mean(axis=1)

+        # CORRECTED: Work with the chunk directly, no more extracting windows
+        # The audio passed is already the chunk for this timestamp
+        x = safe_resample(audio, 16000, self.sample_rate)

+        # Pad to minimum window size if needed (no repeating)
+        min_samples = int(self.sample_rate * self.win_s)
+        if len(x) < min_samples:
+            x = np.pad(x, (0, min_samples - len(x)), mode='constant')

+        # If we have PANNs AT model, use it
+        if self.at_model is not None:
+            # Run inference
+            clipwise_output, _ = self.at_model.inference(x[np.newaxis, :])

+            # Get speech-related classes
+            speech_keywords = [
+                'speech', 'voice', 'talk', 'conversation', 'speaking',
+                'male speech', 'female speech', 'child speech',
+                'narration', 'monologue', 'speech synthesizer'
+            ]

+            speech_indices = []
+            for i, lbl in enumerate(labels):
+                if any(word in lbl.lower() for word in speech_keywords):
+                    speech_indices.append(i)
+
+            if speech_indices:
+                speech_probs = clipwise_output[0, speech_indices]
+                speech_score = float(np.max(speech_probs))
             else:
+                speech_score = float(np.max(clipwise_output[0]))
+        else:
+            # Fallback to spectral features
+            if LIBROSA_AVAILABLE:
+                mel_spec = librosa.feature.melspectrogram(y=x, sr=self.sample_rate, n_mels=64)
                 energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
+                spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=x, sr=self.sample_rate))

                 energy_score = np.clip((energy + 80) / 40, 0, 1)
                 centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
+                speech_score = energy_score * 0.7 + centroid_score * 0.3
+            else:
+                energy = np.sum(x ** 2) / len(x)
+                speech_score = min(energy * 50, 1.0)

         probability = np.clip(speech_score, 0, 1)
         is_speech = probability > 0.4
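`safe_resample` is referenced here but defined elsewhere in app.py, outside this diff. A minimal sketch of what such a helper could look like, assuming it wraps librosa with a linear-interpolation fallback; the implementation below is illustrative, not the one in the repository:

    # Sketch only — not part of this commit
    import numpy as np

    def safe_resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample `audio`, falling back to linear interpolation if librosa is unavailable."""
        if orig_sr == target_sr:
            return audio.astype(np.float32)
        try:
            import librosa
            return librosa.resample(audio.astype(float), orig_sr=orig_sr, target_sr=target_sr)
        except ImportError:
            factor = target_sr / orig_sr
            new_len = int(len(audio) * factor)
            return np.interp(np.linspace(0, len(audio) - 1, new_len),
                             np.arange(len(audio)), audio).astype(np.float32)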
 
         if len(audio.shape) > 1:
             audio = audio.mean(axis=1)

+        # CORRECTED: Work with the chunk directly

         # Convert audio to PANNs sample rate
         if LIBROSA_AVAILABLE:
+            audio_resampled = librosa.resample(audio.astype(float),
                                                orig_sr=16000,
                                                target_sr=self.sample_rate)
         else:
             # Simple resampling fallback
             resample_factor = self.sample_rate / 16000
             audio_resampled = np.interp(
+                np.linspace(0, len(audio) - 1, int(len(audio) * resample_factor)),
+                np.arange(len(audio)),
+                audio
             )

+        # For short audio, pad (no repeating)
+        min_samples = 1 * self.sample_rate # 1 second minimum
         if len(audio_resampled) < min_samples:
+            audio_resampled = np.pad(audio_resampled, (0, min_samples - len(audio_resampled)), mode='constant')

         # Use SED for framewise predictions if available
         if self.sed_model is not None:
 
             if framewise_output.ndim == 3:
                 framewise_output = framewise_output[0] # Remove batch dimension

+            # Get middle frame (corresponding to center of window)
+            frame_idx = framewise_output.shape[0] // 2

             # Get speech-related classes
             speech_keywords = [
 
                 noise_prob = np.mean(clip_probs[0, noise_indices])
                 # Adjust speech probability based on noise
                 speech_prob = speech_prob * (1 - noise_prob * 0.5)

             else:
                 # Fallback if no speech indices found
 
             return VADResult(probability, is_speech, f"{self.model_name} (error)", time.time() - start_time, timestamp)

 class OptimizedAST:
+    """CORRECTED AST with proper 16kHz sample rate and NO CACHE"""
     def __init__(self):
         self.model_name = "AST"
         self.sample_rate = 16000 # AST REQUIRES 16kHz
         self.model = None
         self.feature_extractor = None
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        # NO CACHE - removed cache_window and prediction_cache
         self.load_model()

     def load_model(self):
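The body of load_model() is unchanged by this commit and therefore not shown. For context, a typical transformers setup for an AudioSet-finetuned AST checkpoint looks like the sketch below; the checkpoint name and use of the Auto* classes are assumptions, not a statement of what app.py actually does:

    # Sketch only — not part of this commit
    from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

    ckpt = "MIT/ast-finetuned-audioset-10-10-0.4593"  # assumed AudioSet AST checkpoint
    feature_extractor = AutoFeatureExtractor.from_pretrained(ckpt)
    model = AutoModelForAudioClassification.from_pretrained(ckpt)
    model.eval()
    # model.config.label2id then provides the label-to-index mapping used later in predict().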
 
     def predict(self, audio: np.ndarray, timestamp: float = 0.0, full_audio: np.ndarray = None) -> VADResult:
         start_time = time.time()

         if self.model is None or len(audio) == 0:
             # Enhanced fallback using spectral features
             if len(audio) > 0:
                 energy = np.sum(audio ** 2)

                     spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
                     # Combine multiple features for better speech detection
                     probability = min((energy * 100 + spectral_centroid / 1000) / 2, 1.0)
                 else:
                     probability = min(energy * 50, 1.0)
                 is_speech = probability > 0.25 # Use AST threshold
             else:
                 probability = 0.0

             return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)

         try:
+            # NO CACHE - removed all cache-related code

             if len(audio.shape) > 1:
                 audio = audio.mean(axis=1)

+            # CRITICAL: AST uses 16kHz, input is already at 16kHz
+            audio_for_ast = audio.astype(np.float32)

+            # Pad to minimum 1 second if needed
             min_samples = int(1.0 * self.sample_rate) # 1 second minimum
             if len(audio_for_ast) < min_samples:
+                audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), mode='constant')
+
+            # Feature extraction with NO PADDING to 1024
             inputs = self.feature_extractor(
                 audio_for_ast,
                 sampling_rate=self.sample_rate, # Must be 16kHz
                 return_tensors="pt",
+                padding=False, # CHANGED: No padding to 1024
+                truncation=False # CHANGED: No truncation
             )

             # Move inputs to correct device and dtype
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
             if self.device.type == 'cuda' and hasattr(self.model, 'half'):
                 # Convert inputs to FP16 if model is in FP16
                 inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

             with torch.no_grad():
                 outputs = self.model(**inputs)
                 logits = outputs.logits
                 probs = torch.sigmoid(logits)

             # Find speech-related classes with enhanced keywords
             label2id = self.model.config.label2id
             speech_indices = []

                 if any(word in lbl.lower() for word in speech_keywords):
                     speech_indices.append(idx)

             # Also identify background/noise classes for better discrimination
             noise_keywords = ['silence', 'white noise', 'background']
             noise_indices = []

                     noise_prob = torch.mean(probs[0, noise_indices]).item()
                     # Reduce speech probability if high noise/silence detected
                     speech_prob = speech_prob * (1 - noise_prob * 0.3)

             else:
                 # Fallback to energy-based detection with better calibration
                 energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast) # Normalize by length
                 speech_prob = min(energy * 20, 1.0) # Better scaling

             # Use lower threshold specifically for AST (0.25 instead of 0.4)
             is_speech_ast = speech_prob > 0.25
             result = VADResult(float(speech_prob), is_speech_ast, self.model_name, time.time()-start_time, timestamp)

             return result

         except Exception as e:
 
                 energy = np.sum(audio ** 2) / len(audio) # Normalize by length
                 probability = min(energy * 100, 1.0) # More conservative scaling
                 is_speech = energy > 0.001 # Lower threshold for fallback
             else:
                 probability = 0.0
                 is_speech = False
 
         self.model_windows = {
             "Silero-VAD": 0.032, # 32ms exactly as required (512 samples)
             "WebRTC-VAD": 0.03, # 30ms frames (480 samples)
+            "E-PANNs": 1.0, # CHANGED from 6.0 to 1.0 for better temporal resolution
+            "PANNs": 1.0, # CHANGED from 10.0 to 1.0 for better temporal resolution
+            "AST": 1.0 # 1 second for better temporal resolution
         }

+        # Model-specific hop sizes for efficiency - INCREASED to 20Hz
         self.model_hop_sizes = {
             "Silero-VAD": 0.016, # 16ms hop for Silero (512 samples window)
             "WebRTC-VAD": 0.03, # 30ms hop for WebRTC (match frame duration)
+            "E-PANNs": 0.05, # CHANGED from 0.1 to 0.05 for 20Hz
+            "PANNs": 0.05, # CHANGED from 0.1 to 0.05 for 20Hz
+            "AST": 0.05 # CHANGED from 0.1 to 0.05 for 20Hz
         }

         # Model-specific thresholds for better detection
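As a quick check of the numbers in this hunk: a 0.05 s hop yields 1 / 0.05 = 20 predictions per second, and a 1.0 s window at the 16 kHz input rate spans 16,000 samples (hops of 800 samples). The variable names below are illustrative, not taken from app.py:

    # Sketch only — relation between window, hop and prediction rate
    sample_rate = 16000                       # input rate fed to the models
    window_samples = int(1.0 * sample_rate)   # 1.0 s window -> 16000 samples
    hop_samples = int(0.05 * sample_rate)     # 0.05 s hop   -> 800 samples
    predictions_per_second = 1 / 0.05         # -> 20 Hz, the rate quoted in the comments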
 

             model_results = []

+            # CRITICAL FIX: Always extract chunks, both for short and long audio
+            window_count = 0
+            audio_duration = len(processed_audio) / self.processor.sample_rate
+
+            for i in range(0, len(processed_audio), hop_samples):
+                timestamp = i / self.processor.sample_rate

+                # CRITICAL: Extract the chunk centered on this timestamp
+                start_pos = max(0, i - window_samples // 2)
+                end_pos = min(len(processed_audio), start_pos + window_samples)
+                chunk = processed_audio[start_pos:end_pos]

+                # Pad if necessary (with zeros, not repeating)
+                if len(chunk) < window_samples:
+                    chunk = np.pad(chunk, (0, window_samples - len(chunk)), mode='constant')

+                if window_count < 3: # Log first 3 windows
+                    debug_info.append(f" 🔄 Window {window_count}: t={timestamp:.2f}s, chunk_size={len(chunk)}")

+                # Call predict with the chunk
+                result = self.models[model_name].predict(chunk, timestamp)
+
+                if window_count < 3: # Log first 3 results
+                    debug_info.append(f" 📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")

+                # Use model-specific threshold
+                result.is_speech = result.probability > model_threshold
+                vad_results.append(result)
+                model_results.append(result)
+                window_count += 1

+                # Stop if we've gone past the audio length
+                if timestamp >= audio_duration:
+                    break
+
+            debug_info.append(f" 🎯 Total windows processed: {window_count}")

             # Summary for this model
             if model_results:
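The loop above replaces the old repeat-and-fade strategy with fixed-length, zero-padded chunks centred on each timestamp. The same logic as a standalone helper, with an illustrative name and signature (not from app.py):

    # Sketch only — centred, zero-padded chunk extraction
    import numpy as np

    def extract_centered_chunk(audio: np.ndarray, center_sample: int, window_samples: int) -> np.ndarray:
        """Return a window_samples-long slice centred near center_sample, zero-padded at the end."""
        start_pos = max(0, center_sample - window_samples // 2)
        end_pos = min(len(audio), start_pos + window_samples)
        chunk = audio[start_pos:end_pos]
        if len(chunk) < window_samples:
            chunk = np.pad(chunk, (0, window_samples - len(chunk)), mode='constant')
        return chunk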
 
     ---
     **Models**: Silero-VAD, WebRTC-VAD, E-PANNs, PANNs, AST | **Research**: WASPAA 2025 | **Institution**: University of Surrey, CVSSP

+    **Note**: All models now provide high temporal resolution (20 Hz) for accurate real-time speech detection.
     """)

     return interface