Spaces:
Sleeping
Sleeping
Commit
·
4056c2c
1
Parent(s):
99c49c6
Upd caching + trans SFT saver
Browse files- scritps/cache_test.py +118 -0
- trans_test.py → scritps/trans_test.py +0 -0
- utils/datasets.py +11 -1
- vi/processing.py +50 -20
- vi/translator.py +6 -2
scritps/cache_test.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script to verify the fixes for HF permissions and Vietnamese translation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Add the project root to Python path
|
| 12 |
+
project_root = Path(__file__).parent
|
| 13 |
+
sys.path.insert(0, str(project_root))
|
| 14 |
+
|
| 15 |
+
from vi.translator import VietnameseTranslator
|
| 16 |
+
from vi.processing import translate_sft_row, _validate_vi_translation
|
| 17 |
+
from utils.schema import sft_row
|
| 18 |
+
|
| 19 |
+
# Set up logging
|
| 20 |
+
logging.basicConfig(level=logging.INFO)
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
def test_vietnamese_translation():
|
| 24 |
+
"""Test Vietnamese translation functionality"""
|
| 25 |
+
logger.info("Testing Vietnamese translation...")
|
| 26 |
+
|
| 27 |
+
# Create a sample SFT row
|
| 28 |
+
sample_row = sft_row(
|
| 29 |
+
instruction="Answer the patient's question like a clinician. Be concise and safe.",
|
| 30 |
+
user_input="What are the symptoms of diabetes?",
|
| 31 |
+
output="Common symptoms of diabetes include increased thirst, frequent urination, unexplained weight loss, fatigue, and blurred vision. If you experience these symptoms, please consult a healthcare provider.",
|
| 32 |
+
source="test",
|
| 33 |
+
rid="test_001",
|
| 34 |
+
task="medical_dialogue"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
logger.info(f"Original SFT row: {sample_row}")
|
| 38 |
+
|
| 39 |
+
# Test translation validation
|
| 40 |
+
test_cases = [
|
| 41 |
+
("Hello world", "Xin chào thế giới", True), # Valid Vietnamese
|
| 42 |
+
("Hello world", "Hello world", False), # Same as original (not translated)
|
| 43 |
+
("Hello world", "translation error", False), # Contains error keyword
|
| 44 |
+
("Hello world", "Hi", False), # Too short
|
| 45 |
+
("Hello world", "", False), # Empty
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
logger.info("Testing translation validation...")
|
| 49 |
+
for original, translated, expected in test_cases:
|
| 50 |
+
result = _validate_vi_translation(original, translated)
|
| 51 |
+
status = "✅" if result == expected else "❌"
|
| 52 |
+
logger.info(f"{status} {original} -> {translated}: {result} (expected {expected})")
|
| 53 |
+
|
| 54 |
+
# Test with translator (if available)
|
| 55 |
+
try:
|
| 56 |
+
translator = VietnameseTranslator()
|
| 57 |
+
logger.info("Vietnamese translator initialized successfully")
|
| 58 |
+
|
| 59 |
+
# Try to load the model
|
| 60 |
+
try:
|
| 61 |
+
translator.load_model()
|
| 62 |
+
logger.info("✅ Translation model loaded successfully")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.warning(f"Could not load translation model: {e}")
|
| 65 |
+
logger.info("This is expected if the model is not downloaded yet")
|
| 66 |
+
return
|
| 67 |
+
|
| 68 |
+
# Test translation
|
| 69 |
+
translated_row = translate_sft_row(sample_row, translator)
|
| 70 |
+
logger.info(f"Translated SFT row: {translated_row}")
|
| 71 |
+
|
| 72 |
+
# Check if translation was applied
|
| 73 |
+
original_sft = sample_row["sft"]
|
| 74 |
+
translated_sft = translated_row["sft"]
|
| 75 |
+
|
| 76 |
+
for field in ["instruction", "input", "output"]:
|
| 77 |
+
original_text = original_sft[field]
|
| 78 |
+
translated_text = translated_sft[field]
|
| 79 |
+
|
| 80 |
+
if original_text != translated_text:
|
| 81 |
+
logger.info(f"✅ Field '{field}' was translated")
|
| 82 |
+
logger.info(f" Original: {original_text[:100]}...")
|
| 83 |
+
logger.info(f" Translated: {translated_text[:100]}...")
|
| 84 |
+
else:
|
| 85 |
+
logger.info(f"⚠️ Field '{field}' was not translated (may be due to validation failure)")
|
| 86 |
+
|
| 87 |
+
except Exception as e:
|
| 88 |
+
logger.warning(f"Could not test with actual translator: {e}")
|
| 89 |
+
logger.info("This is expected if the model is not downloaded yet")
|
| 90 |
+
|
| 91 |
+
def test_hf_cache_setup():
|
| 92 |
+
"""Test Hugging Face cache directory setup"""
|
| 93 |
+
logger.info("Testing HF cache setup...")
|
| 94 |
+
|
| 95 |
+
# Test cache directory creation
|
| 96 |
+
cache_dir = os.path.abspath("cache/huggingface")
|
| 97 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 98 |
+
|
| 99 |
+
if os.path.exists(cache_dir) and os.access(cache_dir, os.R_OK | os.W_OK):
|
| 100 |
+
logger.info(f"✅ Cache directory {cache_dir} is accessible")
|
| 101 |
+
else:
|
| 102 |
+
logger.error(f"❌ Cache directory {cache_dir} is not accessible")
|
| 103 |
+
|
| 104 |
+
# Test HF_HOME environment variable
|
| 105 |
+
os.environ["HF_HOME"] = cache_dir
|
| 106 |
+
hf_home = os.getenv("HF_HOME")
|
| 107 |
+
if hf_home == cache_dir:
|
| 108 |
+
logger.info(f"✅ HF_HOME environment variable set to {hf_home}")
|
| 109 |
+
else:
|
| 110 |
+
logger.error(f"❌ HF_HOME environment variable not set correctly")
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
logger.info("Starting fix verification tests...")
|
| 114 |
+
|
| 115 |
+
test_hf_cache_setup()
|
| 116 |
+
test_vietnamese_translation()
|
| 117 |
+
|
| 118 |
+
logger.info("Tests completed!")
|
trans_test.py → scritps/trans_test.py
RENAMED
|
File without changes
|
utils/datasets.py
CHANGED
|
@@ -49,12 +49,22 @@ def hf_download_dataset(repo_id: str, filename: str, repo_type: str = "dataset")
|
|
| 49 |
logger.info(
|
| 50 |
f"[HF] Download {repo_id}/{filename} (type={repo_type}) token={'yes' if token else 'no'}"
|
| 51 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
path = hf_hub_download(
|
| 53 |
repo_id=repo_id,
|
| 54 |
filename=filename,
|
| 55 |
repo_type=repo_type,
|
| 56 |
token=token,
|
| 57 |
-
local_dir=
|
| 58 |
local_dir_use_symlinks=False
|
| 59 |
)
|
| 60 |
try:
|
|
|
|
| 49 |
logger.info(
|
| 50 |
f"[HF] Download {repo_id}/{filename} (type={repo_type}) token={'yes' if token else 'no'}"
|
| 51 |
)
|
| 52 |
+
|
| 53 |
+
# Set cache directory with proper permissions
|
| 54 |
+
cache_dir = os.path.abspath("cache/hf")
|
| 55 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
# Set HF_HOME to avoid permission issues
|
| 58 |
+
hf_home = os.path.abspath("cache/huggingface")
|
| 59 |
+
os.makedirs(hf_home, exist_ok=True)
|
| 60 |
+
os.environ["HF_HOME"] = hf_home
|
| 61 |
+
|
| 62 |
path = hf_hub_download(
|
| 63 |
repo_id=repo_id,
|
| 64 |
filename=filename,
|
| 65 |
repo_type=repo_type,
|
| 66 |
token=token,
|
| 67 |
+
local_dir=cache_dir,
|
| 68 |
local_dir_use_symlinks=False
|
| 69 |
)
|
| 70 |
try:
|
vi/processing.py
CHANGED
|
@@ -36,28 +36,39 @@ def _validate_vi_translation(original: str, translated: str) -> bool:
|
|
| 36 |
if not translated or not isinstance(translated, str):
|
| 37 |
return False
|
| 38 |
|
| 39 |
-
# Check if translation is too short
|
| 40 |
if len(translated.strip()) < 3:
|
| 41 |
return False
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
| 45 |
-
english_chars = len(re.findall(r'[a-zA-Z]', translated))
|
| 46 |
-
total_chars = len(re.sub(r'\s', '', translated))
|
| 47 |
-
if total_chars > 0 and english_chars / total_chars > 0.7:
|
| 48 |
return False
|
| 49 |
|
| 50 |
# Check for common translation failure patterns
|
| 51 |
failure_patterns = [
|
| 52 |
-
"translation
|
| 53 |
-
"not available", "not found", "invalid
|
| 54 |
]
|
| 55 |
translated_lower = translated.lower()
|
| 56 |
for pattern in failure_patterns:
|
| 57 |
if pattern in translated_lower:
|
| 58 |
return False
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
|
| 63 |
"""
|
|
@@ -80,17 +91,36 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
|
|
| 80 |
text_fields = ["instruction", "input", "output"]
|
| 81 |
|
| 82 |
try:
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
logger.debug(f"Translated SFT row with fields: {text_fields}")
|
| 95 |
return translated_row
|
| 96 |
except Exception as e:
|
|
|
|
| 36 |
if not translated or not isinstance(translated, str):
|
| 37 |
return False
|
| 38 |
|
| 39 |
+
# Check if translation is too short
|
| 40 |
if len(translated.strip()) < 3:
|
| 41 |
return False
|
| 42 |
|
| 43 |
+
# If translation is identical to original, it's not a valid translation
|
| 44 |
+
if translated.strip() == original.strip():
|
|
|
|
|
|
|
|
|
|
| 45 |
return False
|
| 46 |
|
| 47 |
# Check for common translation failure patterns
|
| 48 |
failure_patterns = [
|
| 49 |
+
"translation error", "translation failed", "unable to translate",
|
| 50 |
+
"cannot translate", "not available", "not found", "invalid translation"
|
| 51 |
]
|
| 52 |
translated_lower = translated.lower()
|
| 53 |
for pattern in failure_patterns:
|
| 54 |
if pattern in translated_lower:
|
| 55 |
return False
|
| 56 |
|
| 57 |
+
# Check if translation contains Vietnamese characters (basic check)
|
| 58 |
+
import re
|
| 59 |
+
vietnamese_chars = len(re.findall(r'[àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]', translated, re.IGNORECASE))
|
| 60 |
+
total_chars = len(re.sub(r'\s', '', translated))
|
| 61 |
+
|
| 62 |
+
# If there are Vietnamese characters, it's likely a valid translation
|
| 63 |
+
if vietnamese_chars > 0:
|
| 64 |
+
return True
|
| 65 |
+
|
| 66 |
+
# If no Vietnamese characters but significantly different from original, accept it
|
| 67 |
+
# (some translations might not have Vietnamese diacritics)
|
| 68 |
+
if len(translated) > len(original) * 0.5 and len(translated) < len(original) * 2.0:
|
| 69 |
+
return True
|
| 70 |
+
|
| 71 |
+
return False
|
| 72 |
|
| 73 |
def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
|
| 74 |
"""
|
|
|
|
| 91 |
text_fields = ["instruction", "input", "output"]
|
| 92 |
|
| 93 |
try:
|
| 94 |
+
# Create a copy of the row to avoid modifying the original
|
| 95 |
+
translated_row = row.copy()
|
| 96 |
+
|
| 97 |
+
# Translate the SFT fields directly
|
| 98 |
+
sft_data = row.get("sft", {})
|
| 99 |
+
translated_sft = {}
|
| 100 |
+
|
| 101 |
+
for field in text_fields:
|
| 102 |
+
if field in sft_data and isinstance(sft_data[field], str) and sft_data[field].strip():
|
| 103 |
+
try:
|
| 104 |
+
original = sft_data[field]
|
| 105 |
+
translated = translator.translate_text(original)
|
| 106 |
+
|
| 107 |
+
# Validate and sanitize translated field
|
| 108 |
+
if _validate_vi_translation(original, translated):
|
| 109 |
+
translated_sft[field] = _vi_sanitize_text(translated)
|
| 110 |
+
logger.debug(f"Translated field '{field}': '{original[:50]}...' -> '{translated[:50]}...'")
|
| 111 |
+
else:
|
| 112 |
+
logger.warning(f"Invalid Vietnamese translation for field {field}, keeping original")
|
| 113 |
+
translated_sft[field] = original
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logger.error(f"Failed to translate field '{field}': {e}")
|
| 116 |
+
translated_sft[field] = sft_data[field]
|
| 117 |
+
else:
|
| 118 |
+
# Keep original if field doesn't exist or is empty
|
| 119 |
+
translated_sft[field] = sft_data.get(field, "")
|
| 120 |
+
|
| 121 |
+
# Update the translated row
|
| 122 |
+
translated_row["sft"] = translated_sft
|
| 123 |
+
|
| 124 |
logger.debug(f"Translated SFT row with fields: {text_fields}")
|
| 125 |
return translated_row
|
| 126 |
except Exception as e:
|
vi/translator.py
CHANGED
|
@@ -57,16 +57,20 @@ class VietnameseTranslator:
|
|
| 57 |
logger.info(f"Loading translation model: {self.model_name}")
|
| 58 |
logger.info(f"Loading on device: {self.device}")
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# Load tokenizer
|
| 61 |
self.tokenizer = MarianTokenizer.from_pretrained(
|
| 62 |
self.model_name,
|
| 63 |
-
cache_dir=
|
| 64 |
)
|
| 65 |
|
| 66 |
# Load model
|
| 67 |
self.model = MarianMTModel.from_pretrained(
|
| 68 |
self.model_name,
|
| 69 |
-
cache_dir=
|
| 70 |
)
|
| 71 |
|
| 72 |
# Move model to device
|
|
|
|
| 57 |
logger.info(f"Loading translation model: {self.model_name}")
|
| 58 |
logger.info(f"Loading on device: {self.device}")
|
| 59 |
|
| 60 |
+
# Set up cache directory
|
| 61 |
+
cache_dir = os.getenv("HF_HOME", os.path.abspath("cache/huggingface"))
|
| 62 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 63 |
+
|
| 64 |
# Load tokenizer
|
| 65 |
self.tokenizer = MarianTokenizer.from_pretrained(
|
| 66 |
self.model_name,
|
| 67 |
+
cache_dir=cache_dir
|
| 68 |
)
|
| 69 |
|
| 70 |
# Load model
|
| 71 |
self.model = MarianMTModel.from_pretrained(
|
| 72 |
self.model_name,
|
| 73 |
+
cache_dir=cache_dir
|
| 74 |
)
|
| 75 |
|
| 76 |
# Move model to device
|