"""Custom inference handler that turns a text prompt into a MusicGen WAV clip."""

import io

import torch
from audiocraft.models import MusicGen
from scipy.io.wavfile import write


class EndpointHandler:
    """Generate a short music clip from a text prompt and return it as WAV bytes."""

    def __init__(self, path=""):
        """Load the pretrained ``facebook/musicgen-large`` model.

        Args:
            path: Model directory supplied by the hosting runtime. It is not
                used here; the weights are loaded by hub name instead.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # MusicGen is not a torch.nn.Module and has no ``.to()`` method, so the
        # target device has to be handed to ``get_pretrained`` directly.
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        # Every request generates 8 seconds of audio.
        self.model.set_generation_params(duration=8)

    def __call__(self, data):
        """Generate audio for ``data["inputs"]`` and encode it as a WAV file.

        Args:
            data: Request payload; the text prompt is read from the ``"inputs"``
                key, defaulting to ``"chill lofi music"`` when absent.

        Returns:
            ``{"audio": <bytes of a complete WAV file>}``.
            NOTE(review): raw ``bytes`` may not survive a JSON response
            serializer; confirm how the hosting runtime encodes this payload.
        """
        prompt = data.get("inputs", "chill lofi music")
        # Batch of one prompt; the model runs on the device chosen in __init__.
        wav = self.model.generate([prompt])

        # Use the model's own rate rather than a literal that could drift.
        sample_rate = self.model.sample_rate

        # generate() yields (batch, channels, samples). scipy's wavfile.write
        # expects (samples,) for mono or (samples, channels) for multichannel;
        # passing (channels, samples) would be read as thousands of channels.
        clip = wav[0].detach().cpu()
        if clip.shape[0] == 1:
            clip = clip[0]
        else:
            clip = clip.t().contiguous()
        audio = clip.numpy()

        buffer = io.BytesIO()
        write(buffer, sample_rate, audio)
        return {"audio": buffer.getvalue()}