LiamKhoaLe commited on
Commit
4056c2c
·
1 Parent(s): 99c49c6

Upd caching + trans SFT saver

Browse files
scritps/cache_test.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify the fixes for HF permissions and Vietnamese translation
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ # Add the project root to Python path
12
+ project_root = Path(__file__).parent
13
+ sys.path.insert(0, str(project_root))
14
+
15
+ from vi.translator import VietnameseTranslator
16
+ from vi.processing import translate_sft_row, _validate_vi_translation
17
+ from utils.schema import sft_row
18
+
19
+ # Set up logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ def test_vietnamese_translation():
24
+ """Test Vietnamese translation functionality"""
25
+ logger.info("Testing Vietnamese translation...")
26
+
27
+ # Create a sample SFT row
28
+ sample_row = sft_row(
29
+ instruction="Answer the patient's question like a clinician. Be concise and safe.",
30
+ user_input="What are the symptoms of diabetes?",
31
+ output="Common symptoms of diabetes include increased thirst, frequent urination, unexplained weight loss, fatigue, and blurred vision. If you experience these symptoms, please consult a healthcare provider.",
32
+ source="test",
33
+ rid="test_001",
34
+ task="medical_dialogue"
35
+ )
36
+
37
+ logger.info(f"Original SFT row: {sample_row}")
38
+
39
+ # Test translation validation
40
+ test_cases = [
41
+ ("Hello world", "Xin chào thế giới", True), # Valid Vietnamese
42
+ ("Hello world", "Hello world", False), # Same as original (not translated)
43
+ ("Hello world", "translation error", False), # Contains error keyword
44
+ ("Hello world", "Hi", False), # Too short
45
+ ("Hello world", "", False), # Empty
46
+ ]
47
+
48
+ logger.info("Testing translation validation...")
49
+ for original, translated, expected in test_cases:
50
+ result = _validate_vi_translation(original, translated)
51
+ status = "✅" if result == expected else "❌"
52
+ logger.info(f"{status} {original} -> {translated}: {result} (expected {expected})")
53
+
54
+ # Test with translator (if available)
55
+ try:
56
+ translator = VietnameseTranslator()
57
+ logger.info("Vietnamese translator initialized successfully")
58
+
59
+ # Try to load the model
60
+ try:
61
+ translator.load_model()
62
+ logger.info("✅ Translation model loaded successfully")
63
+ except Exception as e:
64
+ logger.warning(f"Could not load translation model: {e}")
65
+ logger.info("This is expected if the model is not downloaded yet")
66
+ return
67
+
68
+ # Test translation
69
+ translated_row = translate_sft_row(sample_row, translator)
70
+ logger.info(f"Translated SFT row: {translated_row}")
71
+
72
+ # Check if translation was applied
73
+ original_sft = sample_row["sft"]
74
+ translated_sft = translated_row["sft"]
75
+
76
+ for field in ["instruction", "input", "output"]:
77
+ original_text = original_sft[field]
78
+ translated_text = translated_sft[field]
79
+
80
+ if original_text != translated_text:
81
+ logger.info(f"✅ Field '{field}' was translated")
82
+ logger.info(f" Original: {original_text[:100]}...")
83
+ logger.info(f" Translated: {translated_text[:100]}...")
84
+ else:
85
+ logger.info(f"⚠️ Field '{field}' was not translated (may be due to validation failure)")
86
+
87
+ except Exception as e:
88
+ logger.warning(f"Could not test with actual translator: {e}")
89
+ logger.info("This is expected if the model is not downloaded yet")
90
+
91
+ def test_hf_cache_setup():
92
+ """Test Hugging Face cache directory setup"""
93
+ logger.info("Testing HF cache setup...")
94
+
95
+ # Test cache directory creation
96
+ cache_dir = os.path.abspath("cache/huggingface")
97
+ os.makedirs(cache_dir, exist_ok=True)
98
+
99
+ if os.path.exists(cache_dir) and os.access(cache_dir, os.R_OK | os.W_OK):
100
+ logger.info(f"✅ Cache directory {cache_dir} is accessible")
101
+ else:
102
+ logger.error(f"❌ Cache directory {cache_dir} is not accessible")
103
+
104
+ # Test HF_HOME environment variable
105
+ os.environ["HF_HOME"] = cache_dir
106
+ hf_home = os.getenv("HF_HOME")
107
+ if hf_home == cache_dir:
108
+ logger.info(f"✅ HF_HOME environment variable set to {hf_home}")
109
+ else:
110
+ logger.error(f"❌ HF_HOME environment variable not set correctly")
111
+
112
+ if __name__ == "__main__":
113
+ logger.info("Starting fix verification tests...")
114
+
115
+ test_hf_cache_setup()
116
+ test_vietnamese_translation()
117
+
118
+ logger.info("Tests completed!")
trans_test.py → scritps/trans_test.py RENAMED
File without changes
utils/datasets.py CHANGED
@@ -49,12 +49,22 @@ def hf_download_dataset(repo_id: str, filename: str, repo_type: str = "dataset")
49
  logger.info(
50
  f"[HF] Download {repo_id}/{filename} (type={repo_type}) token={'yes' if token else 'no'}"
51
  )
 
 
 
 
 
 
 
 
 
 
52
  path = hf_hub_download(
53
  repo_id=repo_id,
54
  filename=filename,
55
  repo_type=repo_type,
56
  token=token,
57
- local_dir=os.path.abspath("cache/hf"),
58
  local_dir_use_symlinks=False
59
  )
60
  try:
 
49
  logger.info(
50
  f"[HF] Download {repo_id}/{filename} (type={repo_type}) token={'yes' if token else 'no'}"
51
  )
52
+
53
+ # Set cache directory with proper permissions
54
+ cache_dir = os.path.abspath("cache/hf")
55
+ os.makedirs(cache_dir, exist_ok=True)
56
+
57
+ # Set HF_HOME to avoid permission issues
58
+ hf_home = os.path.abspath("cache/huggingface")
59
+ os.makedirs(hf_home, exist_ok=True)
60
+ os.environ["HF_HOME"] = hf_home
61
+
62
  path = hf_hub_download(
63
  repo_id=repo_id,
64
  filename=filename,
65
  repo_type=repo_type,
66
  token=token,
67
+ local_dir=cache_dir,
68
  local_dir_use_symlinks=False
69
  )
70
  try:
vi/processing.py CHANGED
@@ -36,28 +36,39 @@ def _validate_vi_translation(original: str, translated: str) -> bool:
36
  if not translated or not isinstance(translated, str):
37
  return False
38
 
39
- # Check if translation is too short or too different in length
40
  if len(translated.strip()) < 3:
41
  return False
42
 
43
- # Check if translation contains too much English (should be mostly Vietnamese)
44
- import re
45
- english_chars = len(re.findall(r'[a-zA-Z]', translated))
46
- total_chars = len(re.sub(r'\s', '', translated))
47
- if total_chars > 0 and english_chars / total_chars > 0.7:
48
  return False
49
 
50
  # Check for common translation failure patterns
51
  failure_patterns = [
52
- "translation", "error", "failed", "unable", "cannot",
53
- "not available", "not found", "invalid", "error"
54
  ]
55
  translated_lower = translated.lower()
56
  for pattern in failure_patterns:
57
  if pattern in translated_lower:
58
  return False
59
 
60
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
63
  """
@@ -80,17 +91,36 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
80
  text_fields = ["instruction", "input", "output"]
81
 
82
  try:
83
- translated_row = translator.translate_dict(row, text_fields)
84
- # Validate and sanitize translated fields
85
- for f in text_fields:
86
- if f in translated_row.get("sft", {}):
87
- original = row.get("sft", {}).get(f, "")
88
- translated = translated_row["sft"][f]
89
- if _validate_vi_translation(original, translated):
90
- translated_row["sft"][f] = _vi_sanitize_text(translated)
91
- else:
92
- logger.warning(f"Invalid Vietnamese translation for field {f}, keeping original")
93
- translated_row["sft"][f] = original
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  logger.debug(f"Translated SFT row with fields: {text_fields}")
95
  return translated_row
96
  except Exception as e:
 
36
  if not translated or not isinstance(translated, str):
37
  return False
38
 
39
+ # Check if translation is too short
40
  if len(translated.strip()) < 3:
41
  return False
42
 
43
+ # If translation is identical to original, it's not a valid translation
44
+ if translated.strip() == original.strip():
 
 
 
45
  return False
46
 
47
  # Check for common translation failure patterns
48
  failure_patterns = [
49
+ "translation error", "translation failed", "unable to translate",
50
+ "cannot translate", "not available", "not found", "invalid translation"
51
  ]
52
  translated_lower = translated.lower()
53
  for pattern in failure_patterns:
54
  if pattern in translated_lower:
55
  return False
56
 
57
+ # Check if translation contains Vietnamese characters (basic check)
58
+ import re
59
+ vietnamese_chars = len(re.findall(r'[àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]', translated, re.IGNORECASE))
60
+ total_chars = len(re.sub(r'\s', '', translated))
61
+
62
+ # If there are Vietnamese characters, it's likely a valid translation
63
+ if vietnamese_chars > 0:
64
+ return True
65
+
66
+ # If no Vietnamese characters but significantly different from original, accept it
67
+ # (some translations might not have Vietnamese diacritics)
68
+ if len(translated) > len(original) * 0.5 and len(translated) < len(original) * 2.0:
69
+ return True
70
+
71
+ return False
72
 
73
  def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
74
  """
 
91
  text_fields = ["instruction", "input", "output"]
92
 
93
  try:
94
+ # Create a copy of the row to avoid modifying the original
95
+ translated_row = row.copy()
96
+
97
+ # Translate the SFT fields directly
98
+ sft_data = row.get("sft", {})
99
+ translated_sft = {}
100
+
101
+ for field in text_fields:
102
+ if field in sft_data and isinstance(sft_data[field], str) and sft_data[field].strip():
103
+ try:
104
+ original = sft_data[field]
105
+ translated = translator.translate_text(original)
106
+
107
+ # Validate and sanitize translated field
108
+ if _validate_vi_translation(original, translated):
109
+ translated_sft[field] = _vi_sanitize_text(translated)
110
+ logger.debug(f"Translated field '{field}': '{original[:50]}...' -> '{translated[:50]}...'")
111
+ else:
112
+ logger.warning(f"Invalid Vietnamese translation for field {field}, keeping original")
113
+ translated_sft[field] = original
114
+ except Exception as e:
115
+ logger.error(f"Failed to translate field '{field}': {e}")
116
+ translated_sft[field] = sft_data[field]
117
+ else:
118
+ # Keep original if field doesn't exist or is empty
119
+ translated_sft[field] = sft_data.get(field, "")
120
+
121
+ # Update the translated row
122
+ translated_row["sft"] = translated_sft
123
+
124
  logger.debug(f"Translated SFT row with fields: {text_fields}")
125
  return translated_row
126
  except Exception as e:
vi/translator.py CHANGED
@@ -57,16 +57,20 @@ class VietnameseTranslator:
57
  logger.info(f"Loading translation model: {self.model_name}")
58
  logger.info(f"Loading on device: {self.device}")
59
 
 
 
 
 
60
  # Load tokenizer
61
  self.tokenizer = MarianTokenizer.from_pretrained(
62
  self.model_name,
63
- cache_dir=os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
64
  )
65
 
66
  # Load model
67
  self.model = MarianMTModel.from_pretrained(
68
  self.model_name,
69
- cache_dir=os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
70
  )
71
 
72
  # Move model to device
 
57
  logger.info(f"Loading translation model: {self.model_name}")
58
  logger.info(f"Loading on device: {self.device}")
59
 
60
+ # Set up cache directory
61
+ cache_dir = os.getenv("HF_HOME", os.path.abspath("cache/huggingface"))
62
+ os.makedirs(cache_dir, exist_ok=True)
63
+
64
  # Load tokenizer
65
  self.tokenizer = MarianTokenizer.from_pretrained(
66
  self.model_name,
67
+ cache_dir=cache_dir
68
  )
69
 
70
  # Load model
71
  self.model = MarianMTModel.from_pretrained(
72
  self.model_name,
73
+ cache_dir=cache_dir
74
  )
75
 
76
  # Move model to device