import spaces import gradio as gr import torch import librosa import numpy as np import soundfile as sf # Replace torchaudio.load / torchaudio.info with soundfile-backed versions. # Why: torchaudio 2.9+ routes torchaudio.load through load_with_torchcodec and # ignores the legacy backend= kwarg. torchcodec is unavailable in this Space, # so any indirect torchaudio.load call (inference.py, dataset_f.py, preprocess.py) # raises "TorchCodec is required for load_with_torchcodec". Bypassing the # backend dispatch entirely is the only stable fix. import torchaudio def _patched_load(filepath, *args, **kwargs): frame_offset = kwargs.pop("frame_offset", 0) num_frames = kwargs.pop("num_frames", -1) channels_first = kwargs.pop("channels_first", True) if len(args) >= 1: frame_offset = args[0] if len(args) >= 2: num_frames = args[1] # librosa decodes wav/flac/mp3/m4a/ogg via soundfile -> audioread fallback. # Returns (channels, time) when mono=False and the file is multi-channel, # else (time,) for mono. Normalize to (channels, time). data, sample_rate = librosa.load(str(filepath), sr=None, mono=False) if data.ndim == 1: data = data[np.newaxis, :] if frame_offset and int(frame_offset) > 0: data = data[:, int(frame_offset):] if num_frames and int(num_frames) > 0: data = data[:, : int(num_frames)] # Length check is on the time axis regardless of layout. # beat_this resamples to 22050 Hz and runs STFT with n_fft=1024 / pad=512; # anything shorter than ~1 s raises # "Padding size should be less than the corresponding input dimension". duration_s = data.shape[-1] / max(sample_rate, 1) if duration_s < 1.5: raise RuntimeError( f"Audio is too short ({duration_s:.2f} s). " "Please upload a clip of at least 2 seconds (a few seconds of music works best)." ) waveform = torch.from_numpy(np.ascontiguousarray(data)).float() # Honor torchaudio's channels_first kwarg. 
beat_this calls with # channels_first=False expecting (time, channels); previously we always # returned (channels, time), so beat_this's signal.mean(1) collapsed the # time axis into a 1- or 2-element vector and torch.stft crashed with # padding (512, 512) at dimension 2 of input [1, 1, 1]. if not channels_first: waveform = waveform.transpose(0, 1).contiguous() return waveform, sample_rate torchaudio.load = _patched_load class _AudioInfo: __slots__ = ("sample_rate", "num_frames", "num_channels", "bits_per_sample", "encoding") def _patched_info(filepath, *args, **kwargs): info = sf.info(str(filepath)) out = _AudioInfo() out.sample_rate = info.samplerate out.num_frames = info.frames out.num_channels = info.channels out.bits_per_sample = 0 out.encoding = info.format return out torchaudio.info = _patched_info from inference import inference from huggingface_hub import hf_hub_download from pathlib import Path import os token = os.getenv("HF_TOKEN") def download_models_from_hub(): """ Download model checkpoints from Hugging Face Model Hub """ model_dir = Path("checkpoints") model_dir.mkdir(exist_ok=True) models = { "main": "EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt", "backup": "step=003432-val_loss=0.0216-val_acc=0.9963.ckpt", "beat_this": "beat_this_final0.ckpt", } downloaded_models = {} for model_name, filename in models.items(): local_path = model_dir / filename if not local_path.exists(): print(f"π₯ Downloading {model_name} model from Hugging Face Hub...") model_path = hf_hub_download( repo_id="mippia/FST-checkpoints", filename=filename, local_dir=str(model_dir), local_dir_use_symlinks=False, token=token, ) print(f"β {model_name} model downloaded successfully!") downloaded_models[model_name] = str(local_path) else: print(f"β {model_name} model already exists locally") downloaded_models[model_name] = str(local_path) return downloaded_models @spaces.GPU def 
detect_ai_audio(audio_file): """ Detect whether the uploaded audio file was generated by AI and format the result based on the standardized output. """ if audio_file is None: return "
Confidence: {confidence}%
Fake Probability: {float(fake_prob) * 100:.2f}%
Real Probability: {float(real_prob) * 100:.2f}%
Raw Output: {raw_output}
{tb}"
)
# Dark-mode-compatible CSS.
# The rules below set only layout, spacing, typography and hover motion; no
# color properties appear in this stylesheet.
# NOTE(review): presumably handed to the Gradio interface defined further
# down; the consuming call is not visible in this part of the file.
custom_css = """
.gradio-container { min-height: 100vh; }
.main-container { border-radius: 15px !important; margin: 20px auto !important; padding: 30px !important; max-width: 800px; }
h1 { text-align: center !important; font-size: 2.5em !important; font-weight: 700 !important; margin-bottom: 15px !important; }
.gradio-markdown p { text-align: center !important; font-size: 1.1em !important; margin-bottom: 20px !important; }
.upload-container { border-radius: 10px !important; padding: 15px !important; margin-bottom: 20px !important; }
.output-container { border-radius: 10px !important; padding: 15px !important; min-height: 150px !important; }
.gr-button { border-radius: 20px !important; padding: 10px 25px !important; font-weight: 600 !important; transition: all 0.2s ease !important; }
.gr-button:hover { transform: translateY(-2px) !important; }
@media (max-width: 768px) {
h1 { font-size: 2em !important; }
.main-container { margin: 10px !important; padding: 20px !important; }
}
"""
# Initialization: make sure the model checkpoints are available locally
# before the interface is built.
print("🚀 Starting FST AI Audio Detection App...")
print("📦 Initializing models...")
models = download_models_from_hub()
if models.get("main"):
    print("✅ Main model ready for inference")
else:
    print("⚠️ Warning: Main model not available, app may not work properly")
# Gradio interface
demo = gr.Interface(
fn=detect_ai_audio,
inputs=gr.Audio(
type="filepath", label="Upload Audio File", elem_classes=["upload-container"]
),
outputs=gr.HTML(label="Detection Result", elem_classes=["output-container"]),
title="Fusion Segment Transformer for AI Generated Music Detection",
description="""
Fusion Segment Transformer: Bi-directional attention guided fusion network for AI Generated Music Detection
Authors: Yumin Kim and Seonghyeon Go
Submitted to ICASSP 2026. Detects AI-generated music by modeling full audio segments with content-structure fusion.
β οΈ Note: On Zero GPU environment, processing may take ~30 seconds per audio file.