"""
Stable Audio Open Gradio Inference App for HuggingFace Spaces
This app provides a simple interface for generating high-quality instrumental music
using Stable Audio Open with the SAO-Instrumental-Finetune model.
Designed to be used as a remote computation tool for WeaveMuse.
Architecture:
- Stable Audio model is loaded OUTSIDE the GPU-decorated function
- Only the inference itself runs on GPU (cost-efficient for HF Spaces Zero GPU)
- Model initialization happens once at startup
"""
import json
import torch
import torchaudio
import gradio as gr
import spaces
from einops import rearrange
from huggingface_hub import hf_hub_download

# Model-related imports from stable-audio-tools
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict
def get_pretrained_model(name="santifiorino/SAO-Instrumental-Finetune"):
    """Download the model config and weights from the Hugging Face Hub and build the model."""
    model_config_path = hf_hub_download(name, filename="model_config.json", repo_type="model")
    with open(model_config_path) as f:
        model_config = json.load(f)
    model = create_model_from_config(model_config)
    # Prefer the safetensors weights; fall back to the .ckpt checkpoint if they are unavailable
    try:
        model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type="model")
    except Exception:
        model_ckpt_path = hf_hub_download(name, filename="SAO_Instrumental_Finetune.ckpt", repo_type="model")
    model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
    return model, model_config
# Load the model outside of the GPU-decorated function
def load_model():
    """
    Load the Stable Audio model outside the GPU function.
    Called once at startup to download and cache the model.
    """
    print("Loading model...")
    model, model_config = get_pretrained_model("santifiorino/SAO-Instrumental-Finetune")
    print("Model loaded successfully.")
    return model, model_config
# --- Load once and keep global (do not reload inside the GPU function) ---
model, model_config = load_model()
model = model.to("cuda")  # ZeroGPU allows .to("cuda") at startup, outside the @spaces.GPU function
SAMPLE_RATE = model_config["sample_rate"]
SAMPLE_SIZE = model_config["sample_size"]
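# sample_size / sample_rate gives the model's maximum clip length (about 47 s here, per the model card)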
@spaces.GPU()
def generate_audio(prompt, seconds_total=30, steps=100, cfg_scale=7):
    """
    Generate audio from a text prompt on the GPU.
    Returns (sample_rate, waveform) so the API returns raw audio, not a file.
    """
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": seconds_total,
    }]
    audio = generate_diffusion_cond(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=SAMPLE_SIZE,
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde",
        device="cuda",
    )
    # [B, C, N] -> [C, B*N]: concatenate any batch items along the time axis
    audio = rearrange(audio, "b c n -> c (b n)")
    audio = audio.to(torch.float32)
    audio = audio / (audio.abs().max() + 1e-12)           # peak-normalize to [-1, 1]
    audio = (audio.clamp(-1, 1) * 32767).to(torch.int16)  # convert to 16-bit PCM
    audio_np = audio.cpu().numpy().T                      # (T, C), the layout Gradio expects
    return SAMPLE_RATE, audio_np
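
# Local smoke test (a hypothetical usage sketch; requires a CUDA device or a Zero GPU allocation):
#   sr, wav = generate_audio("Slow blues guitar", seconds_total=10, steps=50)
#   torchaudio.save("test.wav", torch.from_numpy(wav).T, sr)  # wav is int16 with shape (T, C)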
# Wire generate_audio directly to the interface; the output stays a numpy
# (sample_rate, waveform) tuple rather than a file path.
interface = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the instrumental music...",
            value="Upbeat rock guitar with drums and bass",
        ),
        gr.Slider(1, 47, value=30, label="Duration in Seconds"),
        gr.Slider(10, 150, value=100, step=10, label="Number of Diffusion Steps"),
        gr.Slider(1, 15, value=7, step=0.1, label="CFG Scale"),
    ],
    outputs=gr.Audio(type="numpy", format="wav", label="Generated Music"),
    api_name="generate",  # the API endpoint will be /generate (the default is /predict)
title="🎸 Stable Audio Instrumental Generator",
description="""
Generate high-quality instrumental music at 44.1kHz from text prompts using the SAO-Instrumental-Finetune model.
**Features:**
- 🎹 Piano, guitar, drums, bass, and orchestral instruments
- 🎵 Various musical genres and styles
- ⚡ High-quality stereo audio
- 🎼 Perfect for music composition and production
**Tips:**
- Be specific about instruments, tempo, and mood
- Higher steps = better quality (recommended: 100-120)
- CFG Scale 7-10 works well for most prompts
""",
    examples=[
        ["Energetic rock guitar riff with powerful drums and bass", 30, 100, 7],
        ["Smooth jazz piano trio with upright bass and brushed drums", 35, 110, 8],
        ["Epic orchestral strings and brass with cinematic percussion", 45, 120, 10],
        ["Funky electric bass groove with rhythm guitar and tight drums", 30, 100, 7],
        ["Acoustic guitar fingerpicking with soft percussion", 40, 110, 6],
        ["Electronic synthesizer pads with ambient textures and subtle beats", 35, 100, 7.5],
        ["Classical piano solo with expressive dynamics and sustain pedal", 30, 110, 8],
        ["Blues guitar solo with bending notes over a shuffle rhythm section", 30, 100, 7],
        ["Latin percussion ensemble with congas, bongos, and timbales", 30, 100, 7],
        ["Rock beat played in a treated studio, session drumming on an acoustic kit", 30, 100, 7],
    ],
article="""
---
### About SAO-Instrumental-Finetune
This model is a fine-tuned version of **Stable Audio Open 1.0** specifically trained for instrumental music generation.
**Capabilities:**
- 🎸 **Guitar**: Acoustic, electric, classical, jazz, rock
- 🥁 **Drums**: Rock, jazz, electronic, orchestral percussion
- 🎹 **Piano**: Classical, jazz, modern, ambient
- � **Orchestral**: Strings, brass, woodwinds
- � **Other**: Bass, synthesizers, ethnic instruments
**Technical Details:**
- Model: SAO-Instrumental-Finetune (based on Stable Audio Open 1.0)
- Sample Rate: 44.1kHz (CD quality)
- Max Duration: 47 seconds
- Architecture: Latent diffusion model with conditioning
**Integration:**
This space is designed to work with **WeaveMuse** for AI-assisted music composition.
Use the API endpoint for programmatic access in your music production workflows.
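
A minimal sketch with `gradio_client` (the Space ID below is a placeholder; replace it with this Space's actual ID):

```python
from gradio_client import Client

client = Client("user/space-name")  # hypothetical Space ID
result = client.predict(
    "Upbeat rock guitar with drums and bass",  # prompt
    30,   # duration in seconds
    100,  # diffusion steps
    7,    # CFG scale
    api_name="/generate",
)
print(result)  # gradio_client saves audio outputs locally and returns the file path
```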
---
*Powered by [Stability AI](https://stability.ai/) and [WeaveMuse](https://github.com/manoskary/weavemuse)*
"""
)
# Launch the interface
if __name__ == "__main__":
    interface.launch()