Gabriel Bibbó committed on
Commit ·
5bbaead
1
Parent(s): bcae560
Hotfix: Restore basic functionality - fix AST saturation and PANNs execution
Browse files
app.py
CHANGED
|
@@ -362,6 +362,8 @@ class OptimizedAST:
|
|
| 362 |
self.model = None
|
| 363 |
self.feature_extractor = None
|
| 364 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
| 365 |
self.load_model()
|
| 366 |
|
| 367 |
def load_model(self):
|
|
@@ -401,29 +403,52 @@ class OptimizedAST:
|
|
| 401 |
return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
|
| 402 |
|
| 403 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
if len(audio.shape) > 1:
|
| 405 |
audio = audio.mean(axis=1)
|
| 406 |
|
| 407 |
-
# Use longer context for AST -
|
| 408 |
-
if full_audio is not None and len(full_audio) > self.sample_rate:
|
| 409 |
-
# Take
|
| 410 |
center_pos = int(timestamp * self.sample_rate)
|
| 411 |
-
window_size =
|
| 412 |
|
| 413 |
start_pos = max(0, center_pos - window_size)
|
| 414 |
end_pos = min(len(full_audio), center_pos + window_size)
|
| 415 |
|
| 416 |
-
# Ensure we have at least
|
| 417 |
-
if end_pos - start_pos < self.sample_rate:
|
| 418 |
-
end_pos = min(len(full_audio), start_pos + self.sample_rate)
|
|
|
|
|
|
|
| 419 |
|
| 420 |
audio_for_ast = full_audio[start_pos:end_pos]
|
| 421 |
else:
|
| 422 |
audio_for_ast = audio
|
| 423 |
|
| 424 |
-
# Ensure minimum length for AST
|
| 425 |
-
|
| 426 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
# Feature extraction with proper AST parameters
|
| 429 |
inputs = self.feature_extractor(
|
|
@@ -452,23 +477,33 @@ class OptimizedAST:
|
|
| 452 |
|
| 453 |
if speech_indices:
|
| 454 |
speech_prob = probs[0, speech_indices].mean().item()
|
| 455 |
-
#
|
| 456 |
if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
|
| 457 |
-
speech_prob = min(speech_prob *
|
| 458 |
else:
|
| 459 |
-
# Fallback to energy-based detection
|
| 460 |
-
energy = np.sum(audio_for_ast ** 2)
|
| 461 |
-
speech_prob = min(energy *
|
| 462 |
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
|
| 465 |
except Exception as e:
|
| 466 |
print(f"Error in {self.model_name}: {e}")
|
| 467 |
# Enhanced fallback
|
| 468 |
if len(audio) > 0:
|
| 469 |
-
energy = np.sum(audio ** 2)
|
| 470 |
-
probability = min(energy *
|
| 471 |
-
is_speech = energy > 0.
|
| 472 |
else:
|
| 473 |
probability = 0.0
|
| 474 |
is_speech = False
|
|
@@ -491,6 +526,15 @@ class AudioProcessor:
|
|
| 491 |
self.window_size = 0.064
|
| 492 |
self.hop_size = 0.032
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
self.delay_compensation = 0.0
|
| 495 |
self.correlation_threshold = 0.7
|
| 496 |
|
|
@@ -921,21 +965,24 @@ class VADDemo:
|
|
| 921 |
|
| 922 |
selected_models = list(set([model_a, model_b]))
|
| 923 |
|
| 924 |
-
# Process each window
|
| 925 |
for i in range(0, len(processed_audio) - window_samples, hop_samples):
|
| 926 |
timestamp = i / self.processor.sample_rate
|
| 927 |
chunk = processed_audio[i:i + window_samples]
|
| 928 |
|
| 929 |
for model_name in selected_models:
|
| 930 |
if model_name in self.models:
|
| 931 |
-
#
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
|
|
|
|
|
|
|
|
|
| 939 |
|
| 940 |
delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
|
| 941 |
onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
|
|
@@ -990,10 +1037,6 @@ class VADDemo:
|
|
| 990 |
traceback.print_exc()
|
| 991 |
return None, f"❌ Error: {str(e)}", f"Error details: {traceback.format_exc()}"
|
| 992 |
|
| 993 |
-
# Initialize demo
|
| 994 |
-
print("🎤 Initializing VAD Demo...")
|
| 995 |
-
demo_app = VADDemo()
|
| 996 |
-
|
| 997 |
# ===== GRADIO INTERFACE =====
|
| 998 |
|
| 999 |
def create_interface():
|
|
@@ -1053,7 +1096,7 @@ def create_interface():
|
|
| 1053 |
|
| 1054 |
model_b = gr.Dropdown(
|
| 1055 |
choices=["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"],
|
| 1056 |
-
value="
|
| 1057 |
label="Model B (Bottom Panel)"
|
| 1058 |
)
|
| 1059 |
|
|
@@ -1103,11 +1146,10 @@ def create_interface():
|
|
| 1103 |
|
| 1104 |
return interface
|
| 1105 |
|
|
|
|
|
|
|
|
|
|
| 1106 |
# Create and launch interface
|
| 1107 |
if __name__ == "__main__":
|
| 1108 |
-
# Initialize demo
|
| 1109 |
-
print("🎤 Initializing VAD Demo...")
|
| 1110 |
-
demo_app = VADDemo()
|
| 1111 |
-
|
| 1112 |
interface = create_interface()
|
| 1113 |
interface.launch(share=True, debug=False)
|
|
|
|
| 362 |
self.model = None
|
| 363 |
self.feature_extractor = None
|
| 364 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 365 |
+
self.prediction_cache = {} # Cache to avoid recomputing predictions
|
| 366 |
+
self.cache_window = 1.0 # Cache results per one-second window
|
| 367 |
self.load_model()
|
| 368 |
|
| 369 |
def load_model(self):
|
|
|
|
| 403 |
return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
|
| 404 |
|
| 405 |
try:
|
| 406 |
+
# Cache key based on timestamp rounded to cache window
|
| 407 |
+
cache_key = int(timestamp / self.cache_window)
|
| 408 |
+
|
| 409 |
+
# Check cache first
|
| 410 |
+
if cache_key in self.prediction_cache:
|
| 411 |
+
cached_result = self.prediction_cache[cache_key]
|
| 412 |
+
# Return cached result with updated timestamp
|
| 413 |
+
return VADResult(
|
| 414 |
+
cached_result.probability,
|
| 415 |
+
cached_result.is_speech,
|
| 416 |
+
cached_result.model_name + " (cached)",
|
| 417 |
+
time.time() - start_time,
|
| 418 |
+
timestamp
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
if len(audio.shape) > 1:
|
| 422 |
audio = audio.mean(axis=1)
|
| 423 |
|
| 424 |
+
# Use longer context for AST - preferably 2 seconds
|
| 425 |
+
if full_audio is not None and len(full_audio) >= 2 * self.sample_rate:
|
| 426 |
+
# Take 2-second window centered around current timestamp
|
| 427 |
center_pos = int(timestamp * self.sample_rate)
|
| 428 |
+
window_size = self.sample_rate # 1 second each side
|
| 429 |
|
| 430 |
start_pos = max(0, center_pos - window_size)
|
| 431 |
end_pos = min(len(full_audio), center_pos + window_size)
|
| 432 |
|
| 433 |
+
# Ensure we have at least 2 seconds
|
| 434 |
+
if end_pos - start_pos < 2 * self.sample_rate:
|
| 435 |
+
end_pos = min(len(full_audio), start_pos + 2 * self.sample_rate)
|
| 436 |
+
if end_pos - start_pos < 2 * self.sample_rate:
|
| 437 |
+
start_pos = max(0, end_pos - 2 * self.sample_rate)
|
| 438 |
|
| 439 |
audio_for_ast = full_audio[start_pos:end_pos]
|
| 440 |
else:
|
| 441 |
audio_for_ast = audio
|
| 442 |
|
| 443 |
+
# Ensure minimum length for AST: zero-pad anything shorter than 2 seconds
|
| 444 |
+
min_samples = 2 * self.sample_rate # 2 seconds
|
| 445 |
+
if len(audio_for_ast) < min_samples:
|
| 446 |
+
audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), 'constant')
|
| 447 |
+
|
| 448 |
+
# Truncate if too long (AST can handle up to ~10s, but we'll use 3s max for efficiency)
|
| 449 |
+
max_samples = 3 * self.sample_rate
|
| 450 |
+
if len(audio_for_ast) > max_samples:
|
| 451 |
+
audio_for_ast = audio_for_ast[:max_samples]
|
| 452 |
|
| 453 |
# Feature extraction with proper AST parameters
|
| 454 |
inputs = self.feature_extractor(
|
|
|
|
| 477 |
|
| 478 |
if speech_indices:
|
| 479 |
speech_prob = probs[0, speech_indices].mean().item()
|
| 480 |
+
# Apply more reasonable thresholding for AST
|
| 481 |
if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
|
| 482 |
+
speech_prob = min(speech_prob * 3, 0.7) # Moderate boost, cap at 0.7
|
| 483 |
else:
|
| 484 |
+
# Fallback to energy-based detection with higher threshold
|
| 485 |
+
energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast) # Normalize by length
|
| 486 |
+
speech_prob = min(energy * 50, 1.0)
|
| 487 |
|
| 488 |
+
result = VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
|
| 489 |
+
|
| 490 |
+
# Cache the result
|
| 491 |
+
self.prediction_cache[cache_key] = result
|
| 492 |
+
|
| 493 |
+
# Clean old cache entries (keep only last 10 seconds)
|
| 494 |
+
cache_keys_to_remove = [k for k in self.prediction_cache.keys() if k < cache_key - 10]
|
| 495 |
+
for k in cache_keys_to_remove:
|
| 496 |
+
del self.prediction_cache[k]
|
| 497 |
+
|
| 498 |
+
return result
|
| 499 |
|
| 500 |
except Exception as e:
|
| 501 |
print(f"Error in {self.model_name}: {e}")
|
| 502 |
# Enhanced fallback
|
| 503 |
if len(audio) > 0:
|
| 504 |
+
energy = np.sum(audio ** 2) / len(audio) # Normalize by length
|
| 505 |
+
probability = min(energy * 100, 1.0) # More conservative scaling
|
| 506 |
+
is_speech = energy > 0.001 # Lower threshold for fallback
|
| 507 |
else:
|
| 508 |
probability = 0.0
|
| 509 |
is_speech = False
|
|
|
|
| 526 |
self.window_size = 0.064
|
| 527 |
self.hop_size = 0.032
|
| 528 |
|
| 529 |
+
# Model-specific hop sizes for efficiency
|
| 530 |
+
self.model_hop_sizes = {
|
| 531 |
+
"Silero-VAD": 0.032,
|
| 532 |
+
"WebRTC-VAD": 0.03,
|
| 533 |
+
"E-PANNs": 1.0,
|
| 534 |
+
"PANNs": 1.0,
|
| 535 |
+
"AST": 1.0 # Process AST only once per second
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
self.delay_compensation = 0.0
|
| 539 |
self.correlation_threshold = 0.7
|
| 540 |
|
|
|
|
| 965 |
|
| 966 |
selected_models = list(set([model_a, model_b]))
|
| 967 |
|
| 968 |
+
# Process each window with model-specific hop sizes for efficiency
|
| 969 |
for i in range(0, len(processed_audio) - window_samples, hop_samples):
|
| 970 |
timestamp = i / self.processor.sample_rate
|
| 971 |
chunk = processed_audio[i:i + window_samples]
|
| 972 |
|
| 973 |
for model_name in selected_models:
|
| 974 |
if model_name in self.models:
|
| 975 |
+
# Check if we should process this model at this timestamp
|
| 976 |
+
model_hop = self.processor.model_hop_sizes.get(model_name, self.processor.hop_size)
|
| 977 |
+
if i % int(model_hop * self.processor.sample_rate) == 0:
|
| 978 |
+
# Special handling for AST - pass full audio for context
|
| 979 |
+
if model_name == 'AST':
|
| 980 |
+
result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
|
| 981 |
+
else:
|
| 982 |
+
result = self.models[model_name].predict(chunk, timestamp)
|
| 983 |
+
|
| 984 |
+
result.is_speech = result.probability > threshold
|
| 985 |
+
vad_results.append(result)
|
| 986 |
|
| 987 |
delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
|
| 988 |
onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
|
|
|
|
| 1037 |
traceback.print_exc()
|
| 1038 |
return None, f"❌ Error: {str(e)}", f"Error details: {traceback.format_exc()}"
|
| 1039 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
# ===== GRADIO INTERFACE =====
|
| 1041 |
|
| 1042 |
def create_interface():
|
|
|
|
| 1096 |
|
| 1097 |
model_b = gr.Dropdown(
|
| 1098 |
choices=["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"],
|
| 1099 |
+
value="AST",
|
| 1100 |
label="Model B (Bottom Panel)"
|
| 1101 |
)
|
| 1102 |
|
|
|
|
| 1146 |
|
| 1147 |
return interface
|
| 1148 |
|
| 1149 |
+
# Initialize demo only once
|
| 1150 |
+
demo_app = VADDemo()
|
| 1151 |
+
|
| 1152 |
# Create and launch interface
|
| 1153 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1154 |
interface = create_interface()
|
| 1155 |
interface.launch(share=True, debug=False)
|