from typing import Optional

import torch
from datasets import load_dataset
from transformers import Pipeline, SpeechT5Processor, SpeechT5HifiGan
class TTSPipeline(Pipeline):
    """Text-to-speech pipeline built from a model, a processor and a vocoder.

    The processor tokenizes the input text, ``self.model.generate_speech``
    produces the speech from the token ids and a speaker embedding, and the
    vocoder is handed to ``generate_speech`` to render the final output.

    Args:
        vocoder: A vocoder instance, or a string that is passed to
            ``SpeechT5HifiGan.from_pretrained``. Required.
        processor: A processor instance, or a string that is passed to
            ``SpeechT5Processor.from_pretrained``. Required.
        *args, **kwargs: Forwarded unchanged to ``Pipeline.__init__``.

    Raises:
        ValueError: If ``vocoder`` or ``processor`` is not supplied.
    """

    def __init__(self, *args, vocoder=None, processor=None, **kwargs):
        # Fail fast on missing arguments before the base class does any
        # (potentially expensive) set-up work.
        if vocoder is None:
            raise ValueError("Must pass a vocoder to the TTSPipeline.")
        if processor is None:
            raise ValueError("Must pass a processor to the TTSPipeline.")

        super().__init__(*args, **kwargs)

        # Strings are treated as checkpoint identifiers and resolved here.
        if isinstance(vocoder, str):
            vocoder = SpeechT5HifiGan.from_pretrained(vocoder)
        if isinstance(processor, str):
            processor = SpeechT5Processor.from_pretrained(processor)

        # Assigned after super().__init__(), matching the original ordering.
        self.processor = processor
        self.vocoder = vocoder
        # Lazily populated by _default_speaker_embeddings().
        self._default_xvector = None

    def _default_speaker_embeddings(self) -> torch.Tensor:
        """Return the fallback speaker embedding, loading it at most once.

        The embedding is row 7305 of the ``validation`` split of the
        ``Matthijs/cmu-arctic-xvectors`` dataset, with a leading dimension
        added via ``unsqueeze(0)``. The result is cached on the instance so
        the dataset is not reloaded on every call.
        """
        # getattr keeps this safe for instances created without __init__.
        cached = getattr(self, "_default_xvector", None)
        if cached is None:
            embeddings_dataset = load_dataset(
                "Matthijs/cmu-arctic-xvectors", split="validation"
            )
            cached = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
            self._default_xvector = cached
        return cached

    def _sanitize_parameters(self, **pipeline_parameters):
        """Split keyword arguments into preprocess/forward/postprocess params.

        Only ``speaker_embeddings`` is recognised; it is routed to
        ``preprocess``. Any other keyword is ignored, as before.

        Returns:
            A 3-tuple of dicts: (preprocess params, forward params,
            postprocess params).
        """
        preprocess_params = {}
        # Membership (not truthiness) so an explicit ``None`` can be used to
        # request the default speaker embedding.
        if "speaker_embeddings" in pipeline_parameters:
            preprocess_params["speaker_embeddings"] = pipeline_parameters[
                "speaker_embeddings"
            ]
        return preprocess_params, {}, {}

    def preprocess(self, text, speaker_embeddings: Optional[torch.Tensor] = None):
        """Tokenize ``text`` and attach the speaker embedding to use.

        Args:
            text: Input passed to the processor as ``text=``.
            speaker_embeddings: Embedding forwarded to ``generate_speech``.
                When ``None``, the cached default embedding is used.

        Returns:
            A dict with keys ``'inputs'`` (the processor output, requested
            with ``return_tensors='pt'``) and ``'speaker_embeddings'``.
        """
        inputs = self.processor(text=text, return_tensors='pt')
        if speaker_embeddings is None:
            speaker_embeddings = self._default_speaker_embeddings()
        return {'inputs': inputs, 'speaker_embeddings': speaker_embeddings}

    def _forward(self, model_inputs):
        """Run ``generate_speech`` on the output of :meth:`preprocess`.

        Args:
            model_inputs: Dict with ``'inputs'`` (must contain
                ``'input_ids'``) and ``'speaker_embeddings'``.

        Returns:
            Whatever ``self.model.generate_speech`` returns when given the
            vocoder.
        """
        input_ids = model_inputs['inputs']['input_ids']
        speaker_embeddings = model_inputs['speaker_embeddings']
        # NOTE(review): the vocoder is applied to the model's output, so it
        # presumably has to live on the same device as the inputs; ``.to`` is
        # a no-op when it already does. Confirm against the model's device
        # placement if a device map is in use.
        self.vocoder.to(input_ids.device)
        with torch.no_grad():
            speech = self.model.generate_speech(
                input_ids, speaker_embeddings, vocoder=self.vocoder
            )
        return speech

    def postprocess(self, speech):
        """Return the generated speech unchanged."""
        return speech