#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2024 Yiwei Guo

"""Run VC (voice conversion) inference with a trained model."""

import vec2wav2
from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
import torch
import logging
import argparse
from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
import soundfile as sf
import yaml
import os


def configure_logging(verbose):
    """Set log levels for the library loggers and return a dedicated logger for this script."""
    if verbose:
        logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.DEBUG)
        logging.getLogger().setLevel(logging.DEBUG)
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.ERROR)
        logging.getLogger().setLevel(logging.ERROR)
        logging.basicConfig(level=logging.ERROR)

    script_logger = logging.getLogger("script_logger")
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
    script_logger.addHandler(handler)
    script_logger.setLevel(logging.INFO)
    script_logger.propagate = False
    return script_logger


def vc_args():
    parser = argparse.ArgumentParser()
    # required arguments
    parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
                        help="source wav path")
    parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
                        help="target speaker prompt path")
    parser.add_argument("-o", "--output", default="output.wav", type=str,
                        help="path of the output wav file")
    # optional arguments
    parser.add_argument("--expdir", default="pretrained/", type=str,
                        help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
    parser.add_argument("--checkpoint", default=None, type=str,
                        help="checkpoint (.pkl) to load instead of expdir/generator.ckpt; relative paths are resolved under expdir.")
    parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
                        help="checkpoint or model flag of input token extractor")
    parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
                        help="checkpoint or model flag of speaker prompt extractor")
    parser.add_argument("--prompt-output-layer", default=6, type=int,
                        help="output layer when prompt is extracted from WavLM")
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args()
    return args
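

# Expected layout of the experiment directory, as implied by the default
# arguments above (adjust the paths to wherever your checkpoints actually live):
#   pretrained/
#   ├── config.yml             # model configuration loaded by VoiceConverter
#   ├── generator.ckpt         # vec2wav2 generator checkpoint
#   ├── vq-wav2vec_kmeans.pt   # input token extractor checkpoint
#   └── WavLM-Large.pt         # speaker prompt extractor checkpoint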
class VoiceConverter:
    """Wraps the token extractor, speaker prompt extractor, and vec2wav2 generator for VC inference."""

    def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt",
                 prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
                 checkpoint=None, script_logger=None):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.script_logger = script_logger
        self.log_if_possible(f"Using device: {self.device}")

        # set up token extractor and its feature codebook
        self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
        feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device)
        self.feat_codebook = feat_codebook
        self.feat_codebook_numgroups = feat_codebook_numgroups
        self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")

        # set up prompt extractor
        self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer)
        self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")

        # load VC model config and checkpoint
        self.config_path = os.path.join(expdir, "config.yml")
        with open(self.config_path) as f:
            self.config = yaml.load(f, Loader=yaml.Loader)
        if checkpoint is not None:
            checkpoint = os.path.join(expdir, checkpoint)
        else:
            checkpoint = os.path.join(expdir, "generator.ckpt")
        self.model = load_model(checkpoint, self.config)
        self.log_if_possible(f"Successfully set up VC model from {checkpoint}")
        self.model.backend.remove_weight_norm()
        self.model.eval().to(self.device)

    def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
        """Convert source_audio to the voice of target_audio and write the result to output_path."""
        self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
        source_wav = read_wav_16k(source_audio)
        target_wav = read_wav_16k(target_audio)
        # discrete content tokens from the source utterance
        vq_idx = self.token_extractor.extract(source_wav).long().to(self.device)
        # map token indices back to codebook vectors and add a batch dimension
        vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0)
        # speaker prompt features from the target utterance
        prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device)
        converted = self.model.inference(vqvec, prompt)[-1].view(-1)
        sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate'])
        self.log_if_possible(f"Saved audio file to {output_path}")
        return output_path

    def log_if_possible(self, msg):
        if self.script_logger is not None:
            self.script_logger.info(msg)


if __name__ == "__main__":
    args = vc_args()
    script_logger = configure_logging(args.verbose)

    # read both inputs up front; this fails fast if either file is missing or unreadable
    source_wav = read_wav_16k(args.source)
    target_prompt = read_wav_16k(args.target)

    with torch.no_grad():
        voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
                                         prompt_extractor=args.prompt_extractor,
                                         prompt_output_layer=args.prompt_output_layer,
                                         checkpoint=args.checkpoint, script_logger=script_logger)
        voice_converter.voice_conversion(args.source, args.target, args.output)
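
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script). The file and checkpoint names
# below are the defaults assumed above; replace them with your own paths, and
# substitute the actual filename of this script for "vc_infer.py".
#
# From the command line:
#   python vc_infer.py -s examples/source.wav -t examples/target.wav -o output.wav --verbose
#
# From Python:
#   converter = VoiceConverter(expdir="pretrained/")
#   with torch.no_grad():
#       converter.voice_conversion("examples/source.wav", "examples/target.wav", "converted.wav")
# ---------------------------------------------------------------------------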