Gabriel Bibbó
committed on
Commit · aee7b20
1 Parent(s): d02d086
adjust app.py
app.py
CHANGED
@@ -260,6 +260,7 @@ class OptimizedEPANNs:
     def __init__(self):
         self.model_name = "E-PANNs"
         self.sample_rate = 32000
+        self.win_s = 1.0  # CHANGED from 6.0 to 1.0 for better temporal resolution
         print(f"✅ {self.model_name} initialized")

         # Try to load PANNs AudioTagging as backend for E-PANNs
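For scale, the new 1.0 s analysis window at E-PANNs' 32 kHz input rate means far fewer samples per inference than the previous 6.0 s setting mentioned in the comment. A quick check, plain arithmetic rather than code from app.py:

```python
sample_rate = 32000              # E-PANNs input rate set in __init__
old_win_s, new_win_s = 6.0, 1.0  # old value taken from the "CHANGED from 6.0" comment

print(int(old_win_s * sample_rate))  # 192000 samples per window before this commit
print(int(new_win_s * sample_rate))  # 32000 samples per window after
```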
@@ -281,92 +282,50 @@ class OptimizedEPANNs:
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

-        # Adjust start if we're at the end of audio
-        if end_idx == len(audio) and end_idx - start_idx < window_samples:
-            start_idx = max(0, end_idx - window_samples)
-
-        audio_window = audio[start_idx:end_idx]
-
-        # Convert audio to target sample rate for E-PANNs (32kHz)
-        if LIBROSA_AVAILABLE:
-            # Resample to E-PANNs sample rate
-            audio_resampled = librosa.resample(audio_window.astype(float),
-                                               orig_sr=16000,
-                                               target_sr=self.sample_rate)
-
-                'male speech', 'female speech', 'child speech',
-                'narration', 'monologue'
-            ]
-
-            speech_indices = []
-            for i, lbl in enumerate(labels):
-                if any(word in lbl.lower() for word in speech_keywords):
-                    speech_indices.append(i)
-
-            if speech_indices:
-                speech_probs = clipwise_output[0, speech_indices]
-                speech_score = float(np.max(speech_probs))
-            else:
-                speech_score = float(np.max(clipwise_output[0]))
-        else:
-            energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
-
-            # Use actual non-repeated audio for some features
-            actual_audio_len = min(len(audio_resampled), int(len(audio_window) * self.sample_rate / 16000))
-            actual_audio = audio_resampled[:actual_audio_len]
-
-            spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=actual_audio, sr=self.sample_rate))
-            mfcc = librosa.feature.mfcc(y=actual_audio, sr=self.sample_rate, n_mfcc=13)
-            mfcc_var = np.var(mfcc, axis=1).mean()
-            zcr = np.mean(librosa.feature.zero_crossing_rate(actual_audio))
-
-            # Adjusted scaling for better speech detection
-            energy_score = np.clip((energy + 80) / 40, 0, 1)
-            centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
-
-            speech_score = (energy_score * 0.4 +
-                            centroid_score * 0.2 +
-                            mfcc_score * 0.3 +
-                            zcr_score * 0.1)
-        else:
-            from scipy import signal
-            # Basic fallback without librosa
-            f, t, Sxx = signal.spectrogram(audio_window, 16000)
-            energy = np.mean(10 * np.log10(Sxx + 1e-10))
-            speech_score = np.clip((energy + 100) / 50, 0, 1)
+        # CORRECTED: Work with the chunk directly, no more extracting windows
+        # The audio passed is already the chunk for this timestamp
+        x = safe_resample(audio, 16000, self.sample_rate)
+
+        # Pad to minimum window size if needed (no repeating)
+        min_samples = int(self.sample_rate * self.win_s)
+        if len(x) < min_samples:
+            x = np.pad(x, (0, min_samples - len(x)), mode='constant')
+
+        # If we have PANNs AT model, use it
+        if self.at_model is not None:
+            # Run inference
+            clipwise_output, _ = self.at_model.inference(x[np.newaxis, :])
+
+            # Get speech-related classes
+            speech_keywords = [
+                'speech', 'voice', 'talk', 'conversation', 'speaking',
+                'male speech', 'female speech', 'child speech',
+                'narration', 'monologue', 'speech synthesizer'
+            ]
+
+            speech_indices = []
+            for i, lbl in enumerate(labels):
+                if any(word in lbl.lower() for word in speech_keywords):
+                    speech_indices.append(i)
+
+            if speech_indices:
+                speech_probs = clipwise_output[0, speech_indices]
+                speech_score = float(np.max(speech_probs))
+            else:
+                speech_score = float(np.max(clipwise_output[0]))
+        else:
+            # Fallback to spectral features
+            if LIBROSA_AVAILABLE:
+                mel_spec = librosa.feature.melspectrogram(y=x, sr=self.sample_rate, n_mels=64)
+                energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
+                spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=x, sr=self.sample_rate))

+                energy_score = np.clip((energy + 80) / 40, 0, 1)
+                centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
+                speech_score = energy_score * 0.7 + centroid_score * 0.3
+            else:
+                energy = np.sum(x ** 2) / len(x)
+                speech_score = min(energy * 50, 1.0)

        probability = np.clip(speech_score, 0, 1)
        is_speech = probability > 0.4
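The new E-PANNs path calls `safe_resample`, which is not shown in this diff. A minimal sketch of what such a helper could look like, assuming it prefers librosa and falls back to linear interpolation when librosa is missing (names and behaviour here are assumptions, not taken from app.py):

```python
import numpy as np

def safe_resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Hypothetical resampling helper: librosa if available, linear interpolation otherwise."""
    if orig_sr == target_sr:
        return audio.astype(np.float32)
    try:
        import librosa
        return librosa.resample(audio.astype(float), orig_sr=orig_sr, target_sr=target_sr)
    except ImportError:
        factor = target_sr / orig_sr
        n_out = int(len(audio) * factor)
        return np.interp(
            np.linspace(0, len(audio) - 1, n_out),  # output sample positions
            np.arange(len(audio)),                  # input sample positions
            audio,
        ).astype(np.float32)
```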
@@ -432,54 +391,25 @@ class OptimizedPANNs:
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)

-        window_duration = 10.0  # 10 seconds window for PANNs
-        window_samples = int(window_duration * 16000)  # at 16kHz input rate
-
-        # Calculate the center position for this timestamp
-        center_sample = int(timestamp * 16000)
-        half_window = window_samples // 2
-
-        # Extract window centered at timestamp
-        start_idx = max(0, center_sample - half_window)
-        end_idx = min(len(audio), start_idx + window_samples)
-
-        # Adjust start if we're at the end of audio
-        if end_idx == len(audio) and end_idx - start_idx < window_samples:
-            start_idx = max(0, end_idx - window_samples)
-
-        audio_window = audio[start_idx:end_idx]
-
+        # CORRECTED: Work with the chunk directly
        # Convert audio to PANNs sample rate
        if LIBROSA_AVAILABLE:
-            audio_resampled = librosa.resample(
+            audio_resampled = librosa.resample(audio.astype(float),
                                                orig_sr=16000,
                                                target_sr=self.sample_rate)
        else:
            # Simple resampling fallback
            resample_factor = self.sample_rate / 16000
            audio_resampled = np.interp(
-                np.linspace(0, len(
-                np.arange(len(
+                np.linspace(0, len(audio) - 1, int(len(audio) * resample_factor)),
+                np.arange(len(audio)),
+                audio
            )

-        # For short audio,
-        min_samples =
+        # For short audio, pad (no repeating)
+        min_samples = 1 * self.sample_rate  # 1 second minimum
        if len(audio_resampled) < min_samples:
-            num_repeats = int(np.ceil(min_samples / len(audio_resampled)))
-            audio_repeated = np.tile(audio_resampled, num_repeats)[:min_samples]
-
-            # Apply fade in/out to reduce artifacts
-            fade_len = int(0.1 * self.sample_rate)  # 100ms fade
-            fade_in = np.linspace(0, 1, fade_len)
-            fade_out = np.linspace(1, 0, fade_len)
-
-            audio_repeated[:fade_len] *= fade_in
-            audio_repeated[-fade_len:] *= fade_out
-
-            audio_resampled = audio_repeated
+            audio_resampled = np.pad(audio_resampled, (0, min_samples - len(audio_resampled)), mode='constant')

        # Use SED for framewise predictions if available
        if self.sed_model is not None:
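The behavioural change in this hunk (zero-pad short chunks instead of tiling them with fades) can be exercised in isolation; `sample_rate` below is an assumed stand-in for `self.sample_rate`:

```python
import numpy as np

sample_rate = 32000                       # assumed PANNs input rate
audio_resampled = np.random.randn(8000)   # e.g. a 0.25 s chunk after resampling

min_samples = 1 * sample_rate             # 1 second minimum, as in the diff
if len(audio_resampled) < min_samples:
    # New behaviour: pad the tail with zeros instead of repeating the chunk
    audio_resampled = np.pad(audio_resampled,
                             (0, min_samples - len(audio_resampled)),
                             mode='constant')

print(audio_resampled.shape)  # (32000,)
```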
@@ -492,13 +422,8 @@ class OptimizedPANNs:
            if framewise_output.ndim == 3:
                framewise_output = framewise_output[0]  # Remove batch dimension

-            # Get frame corresponding to
-            if audio_duration > 0:
-                frame_idx = int((timestamp % audio_duration) / audio_duration * framewise_output.shape[0])
-                frame_idx = min(frame_idx, framewise_output.shape[0] - 1)
-            else:
-                frame_idx = 0
+            # Get middle frame (corresponding to center of window)
+            frame_idx = framewise_output.shape[0] // 2

            # Get speech-related classes
            speech_keywords = [
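The middle-frame lookup replaces the old timestamp-proportional index: since each chunk is already centered on its timestamp, the center frame of the framewise output is the one that corresponds to it. A toy illustration (the array shape is an assumption, not taken from app.py):

```python
import numpy as np

# Suppose the SED head returns 101 frames x 527 AudioSet classes for one chunk
framewise_output = np.random.rand(101, 527)

frame_idx = framewise_output.shape[0] // 2   # middle frame = center of the window
frame_probs = framewise_output[frame_idx]    # per-class probabilities at the chunk center
print(frame_idx, frame_probs.shape)          # 50 (527,)
```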
@@ -551,11 +476,6 @@ class OptimizedPANNs:
                    noise_prob = np.mean(clip_probs[0, noise_indices])
                    # Adjust speech probability based on noise
                    speech_prob = speech_prob * (1 - noise_prob * 0.5)
-
-                    # If using repeated audio, scale confidence based on original length
-                    if len(audio_window) < 16000 * 2:  # Less than 2 seconds
-                        confidence_scale = len(audio_window) / (16000 * 2)
-                        speech_prob = speech_prob * (0.5 + 0.5 * confidence_scale)

            else:
                # Fallback if no speech indices found
@@ -579,15 +499,14 @@ class OptimizedPANNs:
            return VADResult(probability, is_speech, f"{self.model_name} (error)", time.time() - start_time, timestamp)

class OptimizedAST:
-    """CORRECTED AST with proper 16kHz sample rate and
+    """CORRECTED AST with proper 16kHz sample rate and NO CACHE"""
    def __init__(self):
        self.model_name = "AST"
        self.sample_rate = 16000  # AST REQUIRES 16kHz
        self.model = None
        self.feature_extractor = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-        self.cache_window = 1.0  # Cache results once per second
+        # NO CACHE - removed cache_window and prediction_cache
        self.load_model()

    def load_model(self):
@@ -616,12 +535,7 @@ class OptimizedAST:
    def predict(self, audio: np.ndarray, timestamp: float = 0.0, full_audio: np.ndarray = None) -> VADResult:
        start_time = time.time()

-        print(f"🔍 AST predict: audio_len={len(audio)}, timestamp={timestamp:.2f}, model_available={self.model is not None}")
-        if full_audio is not None:
-            print(f"🔍 AST: full_audio_len={len(full_audio)}")
-
        if self.model is None or len(audio) == 0:
-            print(f"❌ AST: Model unavailable or empty audio")
            # Enhanced fallback using spectral features
            if len(audio) > 0:
                energy = np.sum(audio ** 2)
@@ -630,10 +544,8 @@ class OptimizedAST:
                    spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
                    # Combine multiple features for better speech detection
                    probability = min((energy * 100 + spectral_centroid / 1000) / 2, 1.0)
-                    print(f"🔄 AST fallback: energy={energy:.6f}, centroid={spectral_centroid:.1f}, prob={probability:.4f}")
                else:
                    probability = min(energy * 50, 1.0)
-                    print(f"🔄 AST fallback (simple): energy={energy:.6f}, prob={probability:.4f}")
                is_speech = probability > 0.25  # Use AST threshold
            else:
                probability = 0.0
@@ -641,91 +553,39 @@ class OptimizedAST:
            return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)

        try:
-            cache_key = int(timestamp / self.cache_window)
-
-            # Check cache first
-            if cache_key in self.prediction_cache:
-                cached_result = self.prediction_cache[cache_key]
-                print(f"✅ AST: Using cached result for t={timestamp:.2f}s")
-                # Return cached result with updated timestamp
-                return VADResult(
-                    cached_result.probability,
-                    cached_result.is_speech,
-                    cached_result.model_name + " (cached)",
-                    time.time() - start_time,
-                    timestamp
-                )
+            # NO CACHE - removed all cache-related code

            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)
-                print(f"🔄 AST: Converted to mono")
-
-            # CRITICAL FIX: AST uses 16kHz, but input is already at 16kHz
-            # So we DON'T need to resample, just ensure it's float32
-            audio = audio.astype(np.float32)
-
-            # Use sliding window approach for temporal resolution
-            window_duration = 1.0  # 1 second windows
-            window_samples = int(window_duration * self.sample_rate)
-
-            # Get window for this timestamp
-            center_sample = int(timestamp * self.sample_rate)
-            half_window = window_samples // 2
-
-            if end_idx == len(audio) and end_idx - start_idx < window_samples:
-                start_idx = max(0, end_idx - window_samples)
-
-            audio_for_ast = audio[start_idx:end_idx]
-            print(f"🔄 AST: Extracted window [{start_idx}:{end_idx}], len={len(audio_for_ast)}")
-
-            # For short audio, use intelligent strategy
+            # CRITICAL: AST uses 16kHz, input is already at 16kHz
+            audio_for_ast = audio.astype(np.float32)
+
+            # Pad to minimum 1 second if needed
            min_samples = int(1.0 * self.sample_rate)  # 1 second minimum
            if len(audio_for_ast) < min_samples:
-                audio_padded[:len(audio_for_ast)] = audio_for_ast
-                audio_for_ast = audio_padded
-                print(f"✅ AST: Padded to {len(audio_for_ast)} samples")
-
-            # Truncate if too long (AST can handle up to ~10s, but we use 1s windows)
-            max_samples = int(1.5 * self.sample_rate)
-            if len(audio_for_ast) > max_samples:
-                audio_for_ast = audio_for_ast[:max_samples]
-                print(f"✂️ AST: Truncated to {len(audio_for_ast)} samples")
-
-            print(f"🔄 AST: Feature extraction...")
-            # Feature extraction with proper AST parameters
+                audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), mode='constant')
+
+            # Feature extraction with NO PADDING to 1024
            inputs = self.feature_extractor(
                audio_for_ast,
                sampling_rate=self.sample_rate,  # Must be 16kHz
                return_tensors="pt",
-                truncation=True
+                padding=False,     # CHANGED: No padding to 1024
+                truncation=False   # CHANGED: No truncation
            )

-            print(f"✅ AST: Features extracted, input_shape={[v.shape if hasattr(v, 'shape') else type(v) for v in inputs.values()]}")
-
            # Move inputs to correct device and dtype
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            if self.device.type == 'cuda' and hasattr(self.model, 'half'):
                # Convert inputs to FP16 if model is in FP16
                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

-            print(f"🚀 AST: Running inference...")
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                probs = torch.sigmoid(logits)

-            print(f"✅ AST: Inference complete, logits_shape={logits.shape}, probs_shape={probs.shape}")
-
            # Find speech-related classes with enhanced keywords
            label2id = self.model.config.label2id
            speech_indices = []
@@ -739,8 +599,6 @@ class OptimizedAST:
                if any(word in lbl.lower() for word in speech_keywords):
                    speech_indices.append(idx)

-            print(f"🔍 AST: Found {len(speech_indices)} speech-related classes")
-
            # Also identify background/noise classes for better discrimination
            noise_keywords = ['silence', 'white noise', 'background']
            noise_indices = []
@@ -758,35 +616,16 @@ class OptimizedAST:
                    noise_prob = torch.mean(probs[0, noise_indices]).item()
                    # Reduce speech probability if high noise/silence detected
                    speech_prob = speech_prob * (1 - noise_prob * 0.3)
-
-                print(f"📈 AST: raw_speech_prob={speech_prob:.4f}")
-
-                # Adjust confidence for short audio
-                if len(audio) < self.sample_rate * 2:  # Less than 2 seconds
-                    confidence_factor = len(audio) / (self.sample_rate * 2)
-                    speech_prob = speech_prob * (0.6 + 0.4 * confidence_factor)
-                    print(f"🔧 AST: Adjusted for short audio, final_prob={speech_prob:.4f}")

            else:
                # Fallback to energy-based detection with better calibration
                energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast)  # Normalize by length
                speech_prob = min(energy * 20, 1.0)  # Better scaling
-                print(f"⚠️ AST: No speech classes found, using energy fallback: energy={energy:.6f}, prob={speech_prob:.4f}")

            # Use lower threshold specifically for AST (0.25 instead of 0.4)
            is_speech_ast = speech_prob > 0.25
            result = VADResult(float(speech_prob), is_speech_ast, self.model_name, time.time()-start_time, timestamp)

-            print(f"✅ AST: final_prob={speech_prob:.4f}, is_speech={is_speech_ast}")
-
-            # Cache the result
-            self.prediction_cache[cache_key] = result
-
-            # Clean old cache entries (keep only last 30 seconds for longer sessions)
-            cache_keys_to_remove = [k for k in self.prediction_cache.keys() if k < cache_key - 30]
-            for k in cache_keys_to_remove:
-                del self.prediction_cache[k]
-
            return result

        except Exception as e:
@@ -798,7 +637,6 @@ class OptimizedAST:
            energy = np.sum(audio ** 2) / len(audio)  # Normalize by length
            probability = min(energy * 100, 1.0)  # More conservative scaling
            is_speech = energy > 0.001  # Lower threshold for fallback
-            print(f"🔄 AST error fallback: energy={energy:.6f}, prob={probability:.4f}")
        else:
            probability = 0.0
            is_speech = False
@@ -825,18 +663,18 @@ class AudioProcessor:
        self.model_windows = {
            "Silero-VAD": 0.032,  # 32ms exactly as required (512 samples)
            "WebRTC-VAD": 0.03,   # 30ms frames (480 samples)
-            "E-PANNs": 6.0,
-            "PANNs": 10.0,
-            "AST": 1.0  #
+            "E-PANNs": 1.0,  # CHANGED from 6.0 to 1.0 for better temporal resolution
+            "PANNs": 1.0,    # CHANGED from 10.0 to 1.0 for better temporal resolution
+            "AST": 1.0       # 1 second for better temporal resolution
        }

-        # Model-specific hop sizes for efficiency
+        # Model-specific hop sizes for efficiency - INCREASED to 20Hz
        self.model_hop_sizes = {
            "Silero-VAD": 0.016,  # 16ms hop for Silero (512 samples window)
            "WebRTC-VAD": 0.03,   # 30ms hop for WebRTC (match frame duration)
-            "E-PANNs": 0.1,
-            "PANNs": 0.1,
-            "AST": 0.1
+            "E-PANNs": 0.05,  # CHANGED from 0.1 to 0.05 for 20Hz
+            "PANNs": 0.05,    # CHANGED from 0.1 to 0.05 for 20Hz
+            "AST": 0.05       # CHANGED from 0.1 to 0.05 for 20Hz
        }

        # Model-specific thresholds for better detection
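For context, the hop size directly sets the output frame rate of the sliding-window VAD. A quick check of the rates implied by these settings (plain arithmetic, not code from app.py):

```python
# Frame rate implied by each hop size (frames per second = 1 / hop)
hops = {"Silero-VAD": 0.016, "WebRTC-VAD": 0.03, "E-PANNs": 0.05, "PANNs": 0.05, "AST": 0.05}
for name, hop in hops.items():
    print(f"{name}: {1.0 / hop:.1f} Hz")
# E-PANNs, PANNs and AST now run at 20 Hz; the previous 0.1 s hop gave 10 Hz.
```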
@@ -1346,78 +1184,42 @@ class VADDemo:

            model_results = []

-            # Always
-            # For models that need long context, we'll use the full audio padded/repeated as needed
-            # but report the timestamp based on the sliding window position
-            if window_count < 3:  # Log first 3 windows
-                debug_info.append(f"   🔄 Window {window_count}: t={timestamp:.2f}s")
-
-            # Special handling for different models
-            if model_name == 'AST':
-                result = self.models[model_name].predict(processed_audio, timestamp, full_audio=processed_audio)
-            else:
-                result = self.models[model_name].predict(processed_audio, timestamp)
-
-            if window_count < 3:  # Log first 3 results
-                debug_info.append(f"   📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")
-
-            # Use model-specific threshold
-            result.is_speech = result.probability > model_threshold
-            vad_results.append(result)
-            model_results.append(result)
-            window_count += 1
-
-            # Stop if we've gone past the audio length
-            if timestamp >= audio_duration:
-                break
-
-                start_pos = i
-                end_pos = min(len(processed_audio), i + window_samples)
-                chunk = processed_audio[start_pos:end_pos]
-
-                if window_count < 3:  # Log first 3 windows
-                    debug_info.append(f"   🔄 Window {window_count}: t={timestamp:.2f}s, size={len(chunk)}")
-
-                # Special handling for different models
-                if model_name == 'AST':
-                    result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
-                else:
-                    result = self.models[model_name].predict(chunk, timestamp)
-
-                if window_count < 3:  # Log first 3 results
-                    debug_info.append(f"   📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")
-
-                # Use model-specific threshold
-                result.is_speech = result.probability > model_threshold
-                vad_results.append(result)
-                model_results.append(result)
-                window_count += 1
+            # CRITICAL FIX: Always extract chunks, both for short and long audio
+            window_count = 0
+            audio_duration = len(processed_audio) / self.processor.sample_rate
+
+            for i in range(0, len(processed_audio), hop_samples):
+                timestamp = i / self.processor.sample_rate
+
+                # CRITICAL: Extract the chunk centered on this timestamp
+                start_pos = max(0, i - window_samples // 2)
+                end_pos = min(len(processed_audio), start_pos + window_samples)
+                chunk = processed_audio[start_pos:end_pos]
+
+                # Pad if necessary (with zeros, not repeating)
+                if len(chunk) < window_samples:
+                    chunk = np.pad(chunk, (0, window_samples - len(chunk)), mode='constant')
+
+                if window_count < 3:  # Log first 3 windows
+                    debug_info.append(f"   🔄 Window {window_count}: t={timestamp:.2f}s, chunk_size={len(chunk)}")
+
+                # Call predict with the chunk
+                result = self.models[model_name].predict(chunk, timestamp)
+
+                if window_count < 3:  # Log first 3 results
+                    debug_info.append(f"   📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")
+
+                # Use model-specific threshold
+                result.is_speech = result.probability > model_threshold
+                vad_results.append(result)
+                model_results.append(result)
+                window_count += 1
+
+                # Stop if we've gone past the audio length
+                if timestamp >= audio_duration:
+                    break
+
+            debug_info.append(f"   🎯 Total windows processed: {window_count}")

            # Summary for this model
            if model_results:
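The centered-chunk extraction introduced here can be exercised on its own. The sketch below assumes a 16 kHz processing rate, a 1.0 s window and a 0.05 s hop (the values used elsewhere in this commit); the dummy signal is only for illustration:

```python
import numpy as np

sample_rate = 16000                      # assumed processing rate
window_samples = int(1.0 * sample_rate)  # 1.0 s analysis window
hop_samples = int(0.05 * sample_rate)    # 0.05 s hop -> 20 Hz frame rate

processed_audio = np.random.randn(3 * sample_rate)  # 3 s of dummy audio

chunks = []
for i in range(0, len(processed_audio), hop_samples):
    timestamp = i / sample_rate
    # Chunk centered on the current position, zero-padded at the edges
    start = max(0, i - window_samples // 2)
    end = min(len(processed_audio), start + window_samples)
    chunk = processed_audio[start:end]
    if len(chunk) < window_samples:
        chunk = np.pad(chunk, (0, window_samples - len(chunk)), mode='constant')
    chunks.append((timestamp, chunk))

print(len(chunks))         # 60 chunks for 3 s of audio at 20 Hz
print(chunks[0][1].shape)  # (16000,) each
```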
@@ -1594,7 +1396,7 @@ def create_interface():
    ---
    **Models**: Silero-VAD, WebRTC-VAD, E-PANNs, PANNs, AST | **Research**: WASPAA 2025 | **Institution**: University of Surrey, CVSSP

-    **Note**:
+    **Note**: All models now provide high temporal resolution (20Hz) for accurate real-time speech detection.
    """)

    return interface