Finnish-NLP committed on
Commit
d15775e
·
1 Parent(s): 7f575b6

Improve setup for multi-GPU support and fix inference docs

Browse files

- install_dependencies.sh: auto-detect GPU compute capability and
install PyTorch 2.10.0+cu128 for Blackwell (sm_120+) or
2.5.1+cu124 for older GPUs; add cuDNN library path fix
- requirements.txt: remove torch version pins (managed by install
script), add missing huggingface_hub/requests deps, clarify that
chatterbox TTS is vendored in src/
- devcontainer.json: fix name, fix wrong repo path, add git-lfs pull,
auto-clone from HF if starting from blank container, run full setup
in postCreateCommand
- README.md: add Quick Start section with GPU compat notes; fix
inference example (was using nonexistent ChatterboxMultilingualTTS,
now uses correct ChatterboxTTS from src.chatterbox_.tts)

.devcontainer/devcontainer.json CHANGED
@@ -1,21 +1,26 @@
1
  {
2
- "name": "Chatterbox A100 Optimized",
 
 
 
3
  "image": "unsloth/unsloth:2025.10.1-pt2.8.0-cu12.8-llamacpp-integration",
4
-
5
  "forwardPorts": [8888],
6
-
7
  "containerEnv": {
8
  "JUPYTER_PASSWORD": "MASKED_PASSWORD",
9
- "USER_PASSWORD": "unsloth"
 
 
10
  },
11
-
12
  "runArgs": [
13
  "--gpus=all",
14
  "--shm-size=64gb"
15
  ],
16
-
17
  "remoteUser": "root",
18
-
19
  "customizations": {
20
  "vscode": {
21
  "extensions": [
@@ -26,5 +31,9 @@
26
  }
27
  },
28
 
29
- "postCreateCommand": "apt-get update && apt-get install -y git ffmpeg libsndfile1 && chmod -R 777 /workspaces && cd /workspaces/work/chatterbox-finetuning"
 
 
 
 
30
  }
 
1
  {
2
+ "name": "Chatterbox Finnish TTS",
3
+
4
+ // Unsloth image with CUDA 12.8 — supports Blackwell (sm_120+) and older GPUs.
5
+ // install_dependencies.sh selects the right PyTorch build automatically.
6
  "image": "unsloth/unsloth:2025.10.1-pt2.8.0-cu12.8-llamacpp-integration",
7
+
8
  "forwardPorts": [8888],
9
+
10
  "containerEnv": {
11
  "JUPYTER_PASSWORD": "MASKED_PASSWORD",
12
+ "USER_PASSWORD": "unsloth",
13
+ // Optional: set your HuggingFace token here if the repo is private
14
+ "HF_TOKEN": ""
15
  },
16
+
17
  "runArgs": [
18
  "--gpus=all",
19
  "--shm-size=64gb"
20
  ],
21
+
22
  "remoteUser": "root",
23
+
24
  "customizations": {
25
  "vscode": {
26
  "extensions": [
 
31
  }
32
  },
33
 
34
+ // postCreateCommand handles two cases:
35
+ // A) Standard VS Code / Codespace flow: repo is already cloned, just pull LFS weights
36
+ // B) Blank container (e.g. docker run): clones the full repo from HuggingFace first
37
+ // Then in both cases: install dependencies and download pretrained base models.
38
+ "postCreateCommand": "apt-get update -qq && apt-get install -y git-lfs ffmpeg libsndfile1 && git lfs install && if [ ! -f inference_example.py ]; then git clone https://huggingface.co/Finnish-NLP/Chatterbox-Finnish /workspace/Chatterbox-Finnish && cd /workspace/Chatterbox-Finnish; fi && git lfs pull && bash install_dependencies.sh && python setup.py"
39
  }
README.md CHANGED
@@ -101,27 +101,69 @@ We used `sweep_params.py` to identify the "Golden Settings" for the most natural
101
 
102
  ---
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  ## 🏃 Running Inference
105
 
106
  ```python
107
- from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS
 
 
 
 
 
108
 
109
- # 1. Load the engine
110
- engine = ChatterboxMultilingualTTS.from_local("./pretrained_models", device="cuda")
111
 
112
- # 2. Inject weights (e.g., best_finnish_multilingual_cp986.safetensors)
113
- # engine.t3.load_state_dict(...)
 
 
114
 
115
  # 3. Generate with Finnish-optimized parameters
116
  wav = engine.generate(
117
- text="Suomen kieli on poikkeuksellisen kaunista kuunneltavaa.",
118
- language_id="fi",
119
- audio_prompt_path="path/to/reference.wav",
120
- repetition_penalty=1.5,
121
  temperature=0.8,
122
- exaggeration=0.5,
123
- cfg_weight=0.3
124
  )
 
 
 
 
 
 
 
 
125
  ```
126
 
127
  ---
 
101
 
102
  ---
103
 
104
+ ## 🚀 Quick Start
105
+
106
+ ### Option A — Dev Container (recommended)
107
+
108
+ Open this repo in VS Code with the [Dev Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension. Everything — dependencies, base model weights, GPU detection — is handled automatically by `postCreateCommand`.
109
+
110
+ ### Option B — Manual Setup
111
+
112
+ ```bash
113
+ # 1. Clone (with LFS for model weights)
114
+ git clone https://huggingface.co/Finnish-NLP/Chatterbox-Finnish
115
+ cd Chatterbox-Finnish
116
+
117
+ # 2. Install dependencies (auto-detects your GPU architecture)
118
+ bash install_dependencies.sh
119
+
120
+ # 3. Download pretrained base models from ResembleAI
121
+ python setup.py
122
+
123
+ # 4. Run inference
124
+ python inference_example.py
125
+ ```
126
+
127
+ > **GPU compatibility:** The install script detects your GPU and picks the right PyTorch build automatically:
128
+ > - **Blackwell (sm_120+)** e.g. RTX PRO 6000 → PyTorch 2.10.0 + CUDA 12.8
129
+ > - **Older GPUs (A100, RTX 30/40xx, etc.)** → PyTorch 2.5.1 + CUDA 12.4
130
+
131
+ ---
132
+
133
  ## 🏃 Running Inference
134
 
135
  ```python
136
+ import torch
137
+ import soundfile as sf
138
+ from src.chatterbox_.tts import ChatterboxTTS
139
+ from safetensors.torch import load_file
140
+
141
+ device = "cuda" if torch.cuda.is_available() else "cpu"
142
 
143
+ # 1. Load the base engine
144
+ engine = ChatterboxTTS.from_local("./pretrained_models", device=device)
145
 
146
+ # 2. Inject Finnish fine-tuned weights
147
+ checkpoint = load_file("./models/best_finnish_multilingual_cp986.safetensors")
148
+ t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint.items()}
149
+ engine.t3.load_state_dict(t3_state, strict=False)
150
 
151
  # 3. Generate with Finnish-optimized parameters
152
  wav = engine.generate(
153
+ text="Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä.",
154
+ audio_prompt_path="./samples/reference_finnish.wav",
155
+ repetition_penalty=1.2,
 
156
  temperature=0.8,
157
+ exaggeration=0.6,
 
158
  )
159
+
160
+ sf.write("output.wav", wav.squeeze().cpu().numpy(), engine.sr)
161
+ ```
162
+
163
+ Or just run the included example script directly:
164
+
165
+ ```bash
166
+ python inference_example.py # outputs output_finnish.wav
167
  ```
168
 
169
  ---
install_dependencies.sh CHANGED
@@ -1,33 +1,69 @@
1
  #!/bin/bash
2
- # Chatterbox Finetuning - Dependency Installation Script
3
- # This script ensures correct PyTorch and dependency versions are installed
4
 
5
  set -e # Exit on error
6
 
7
  echo "===================================="
8
- echo "Chatterbox Finetuning Setup"
9
  echo "===================================="
10
 
11
  # Check Python version
12
  PYTHON_VERSION=$(python --version 2>&1 | grep -oP '(?<=Python )\d+\.\d+')
13
  echo "Python version: $PYTHON_VERSION"
14
 
15
- # Uninstall conflicting packages if they exist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  echo ""
17
  echo "Step 1: Removing conflicting packages..."
18
  pip uninstall -y torch torchvision torchaudio xformers flash-attn 2>/dev/null || true
19
 
20
- # Install correct PyTorch version
21
  echo ""
22
- echo "Step 2: Installing PyTorch 2.5.1 with CUDA 12.4..."
23
- pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
 
 
 
 
24
 
25
  # Install xformers
26
  echo ""
27
- echo "Step 3: Installing xformers..."
28
- pip install xformers==0.0.28.post3 --index-url https://download.pytorch.org/whl/cu124
29
 
30
- # Install torchao (compatible version)
31
  echo ""
32
  echo "Step 4: Installing torchao..."
33
  pip install torchao==0.6.1
@@ -37,6 +73,31 @@ echo ""
37
  echo "Step 5: Installing remaining dependencies..."
38
  pip install -r requirements.txt
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Verify installation
41
  echo ""
42
  echo "===================================="
@@ -46,12 +107,14 @@ python -c "
46
  import torch
47
  import xformers
48
  import transformers
49
- print(f'PyTorch: {torch.__version__}')
50
- print(f'xformers: {xformers.__version__}')
51
- print(f'Transformers: {transformers.__version__}')
52
- print(f'CUDA available: {torch.cuda.is_available()}')
53
  if torch.cuda.is_available():
54
- print(f'CUDA version: {torch.version.cuda}')
 
 
55
  "
56
 
57
  echo ""
@@ -60,6 +123,7 @@ echo "Installation complete!"
60
  echo "===================================="
61
  echo ""
62
  echo "Next steps:"
63
- echo "1. Run: python setup.py (to download pretrained models)"
64
- echo "2. Run: python train.py (to start training)"
65
- echo ""
 
 
1
  #!/bin/bash
2
+ # Chatterbox Finnish TTS - Dependency Installation Script
3
+ # Automatically selects the correct PyTorch/CUDA version for your GPU.
4
 
5
  set -e # Exit on error
6
 
7
  echo "===================================="
8
+ echo "Chatterbox Finnish TTS Setup"
9
  echo "===================================="
10
 
11
  # Check Python version
12
  PYTHON_VERSION=$(python --version 2>&1 | grep -oP '(?<=Python )\d+\.\d+')
13
  echo "Python version: $PYTHON_VERSION"
14
 
15
+ # Detect GPU compute capability and select appropriate PyTorch build
16
+ echo ""
17
+ echo "Detecting GPU architecture..."
18
+ IS_BLACKWELL=$(python -c "
19
+ import subprocess
20
+ try:
21
+ r = subprocess.run(
22
+ ['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader'],
23
+ capture_output=True, text=True
24
+ )
25
+ caps = [float(c.strip()) for c in r.stdout.strip().splitlines() if c.strip()]
26
+ print('1' if caps and max(caps) >= 12.0 else '0')
27
+ except Exception:
28
+ print('0')
29
+ " 2>/dev/null || echo "0")
30
+
31
+ if [ "$IS_BLACKWELL" = "1" ]; then
32
+ echo "Blackwell GPU detected (sm_120+) — using PyTorch 2.10.0 + CUDA 12.8"
33
+ TORCH_VERSION="2.10.0"
34
+ TORCHVISION_VERSION="0.25.0"
35
+ TORCHAUDIO_VERSION="2.10.0"
36
+ CUDA_TAG="cu128"
37
+ XFORMERS_VERSION="0.0.35"
38
+ else
39
+ echo "Pre-Blackwell GPU detected — using PyTorch 2.5.1 + CUDA 12.4"
40
+ TORCH_VERSION="2.5.1"
41
+ TORCHVISION_VERSION="0.20.1"
42
+ TORCHAUDIO_VERSION="2.5.1"
43
+ CUDA_TAG="cu124"
44
+ XFORMERS_VERSION="0.0.28.post3"
45
+ fi
46
+
47
+ # Uninstall conflicting packages
48
  echo ""
49
  echo "Step 1: Removing conflicting packages..."
50
  pip uninstall -y torch torchvision torchaudio xformers flash-attn 2>/dev/null || true
51
 
52
+ # Install PyTorch
53
  echo ""
54
+ echo "Step 2: Installing PyTorch ${TORCH_VERSION} with CUDA ${CUDA_TAG}..."
55
+ pip install \
56
+ torch==${TORCH_VERSION} \
57
+ torchvision==${TORCHVISION_VERSION} \
58
+ torchaudio==${TORCHAUDIO_VERSION} \
59
+ --index-url https://download.pytorch.org/whl/${CUDA_TAG}
60
 
61
  # Install xformers
62
  echo ""
63
+ echo "Step 3: Installing xformers ${XFORMERS_VERSION}..."
64
+ pip install xformers==${XFORMERS_VERSION} --index-url https://download.pytorch.org/whl/${CUDA_TAG}
65
 
66
+ # Install torchao (compatible with both PyTorch versions)
67
  echo ""
68
  echo "Step 4: Installing torchao..."
69
  pip install torchao==0.6.1
 
73
  echo "Step 5: Installing remaining dependencies..."
74
  pip install -r requirements.txt
75
 
76
+ # Fix potential cuDNN conflict: ensure PyTorch's bundled cuDNN takes priority
77
+ echo ""
78
+ echo "Step 6: Configuring cuDNN library path..."
79
+ CUDNN_PATH=$(python -c "
80
+ import os
81
+ try:
82
+ import nvidia.cudnn
83
+ print(os.path.join(os.path.dirname(nvidia.cudnn.__file__), 'lib'))
84
+ except Exception:
85
+ print('')
86
+ " 2>/dev/null)
87
+
88
+ if [ -n "$CUDNN_PATH" ] && [ -d "$CUDNN_PATH" ]; then
89
+ PROFILE_LINE="export LD_LIBRARY_PATH=${CUDNN_PATH}:\$LD_LIBRARY_PATH"
90
+ # Add to ~/.bashrc if not already present
91
+ if ! grep -qF "$CUDNN_PATH" ~/.bashrc 2>/dev/null; then
92
+ echo "$PROFILE_LINE" >> ~/.bashrc
93
+ fi
94
+ # Apply for the current session
95
+ export LD_LIBRARY_PATH="${CUDNN_PATH}:${LD_LIBRARY_PATH}"
96
+ echo "cuDNN path set to: $CUDNN_PATH"
97
+ else
98
+ echo "No bundled cuDNN found — skipping."
99
+ fi
100
+
101
  # Verify installation
102
  echo ""
103
  echo "===================================="
 
107
  import torch
108
  import xformers
109
  import transformers
110
+ print(f' PyTorch: {torch.__version__}')
111
+ print(f' xformers: {xformers.__version__}')
112
+ print(f' Transformers: {transformers.__version__}')
113
+ print(f' CUDA available: {torch.cuda.is_available()}')
114
  if torch.cuda.is_available():
115
+ print(f' CUDA version: {torch.version.cuda}')
116
+ props = torch.cuda.get_device_properties(0)
117
+ print(f' GPU: {props.name} (sm_{props.major}{props.minor})')
118
  "
119
 
120
  echo ""
 
123
  echo "===================================="
124
  echo ""
125
  echo "Next steps:"
126
+ echo "1. Run: python setup.py (download pretrained base models)"
127
+ echo "2. Run: python inference_example.py (run Finnish TTS inference)"
128
+ echo "3. Run: python train.py (optional: start fine-tuning)"
129
+ echo ""
requirements.txt CHANGED
@@ -1,22 +1,23 @@
1
- # Core PyTorch - Using 2.5.1 for stable xformers/flash-attn support
2
- --extra-index-url https://download.pytorch.org/whl/cu124
3
- torch==2.5.1
4
- torchaudio==2.5.1
5
- torchvision==0.20.1
6
 
7
- # Core dependencies with pinned versions for stability
8
  transformers==4.46.3
9
- xformers==0.0.28.post3
10
- torchao==0.6.1
11
  diffusers==0.29.0
12
  peft==0.17.1
13
 
14
- # Chatterbox TTS dependencies
15
- # Note: chatterbox-tts itself is installed via install_dependencies.sh --no-deps
16
- # to avoid strict torch==2.6.0 conflict
 
 
 
 
17
  resemble-perth==1.0.1
18
  conformer==0.3.2
19
  s3tokenizer==0.3.0
 
20
 
21
  # Audio processing
22
  silero-vad==6.2.0
@@ -34,3 +35,4 @@ tensorboard
34
  omegaconf
35
  hf_transfer
36
  gdown
 
 
1
+ # PyTorch is installed separately by install_dependencies.sh,
2
+ # which auto-detects your GPU and picks the right CUDA build.
3
+ # Do not pin torch/torchaudio/torchvision here.
 
 
4
 
5
+ # Transformers stack
6
  transformers==4.46.3
 
 
7
  diffusers==0.29.0
8
  peft==0.17.1
9
 
10
+ # xformers is also installed by install_dependencies.sh (version depends on GPU)
11
+
12
+ # torchao
13
+ torchao==0.6.1
14
+
15
+ # Chatterbox TTS source is bundled under src/chatterbox_/.
16
+ # These are the runtime deps it needs:
17
  resemble-perth==1.0.1
18
  conformer==0.3.2
19
  s3tokenizer==0.3.0
20
+ huggingface_hub
21
 
22
  # Audio processing
23
  silero-vad==6.2.0
 
35
  omegaconf
36
  hf_transfer
37
  gdown
38
+ requests