|
|
|
|
|
|
|
|
import sys |
|
|
import os |
|
|
import re |
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from huggingface_hub import login |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Put this script's own directory at the front of the import search path.
_script_path = os.path.abspath(__file__)
project_root = os.path.dirname(_script_path)
sys.path.insert(0, project_root)
|
|
|
|
|
|
|
|
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        """Local stand-in for the `spaces` module when it is not installed."""

        def GPU(self, *args, **kwargs):
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            parameterized one (``@spaces.GPU(...)``). Any configuration
            arguments are accepted and ignored.
            """
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func

            # Bare usage passes the function itself as the only positional
            # argument; decorate it right away instead of handing back the
            # decorator factory (which would replace the function).
            if len(args) == 1 and not kwargs and callable(args[0]):
                return decorator(args[0])
            return decorator

    spaces = DummySpaces()
|
|
|
|
|
|
|
|
# Run python-dotenv first, then read the Hugging Face token from the environment.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Without a token the hub login below cannot proceed, so stop immediately.
if not HF_TOKEN:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")

print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)
|
|
|
|
|
|
|
|
# Hub repository that provides both the tokenizer and the causal LM weights.
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")

# Use bfloat16 on CUDA and float32 on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using device: {device}, dtype: {dtype} ---")

try:
    # NOTE(review): trust_remote_code=True allows Python code shipped in the
    # model repository to run locally; keep MODEL_NAME pointed at a trusted repo.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    print("--- Model and Tokenizer Loaded Successfully ---")
except Exception as e:
    # Chain the original exception explicitly so the root cause stays visible
    # in the traceback.
    raise RuntimeError(f"FATAL: Could not load components. Error: {e}") from e
|
|
|
|
|
|
|
|
def chunk_text(text: str, max_size: int) -> list[str]:
    """Split *text* into stripped, non-empty chunks of at most *max_size* characters.

    Within each window of ``max_size`` characters the split point is chosen as:

    1. the last ``'.'`` in the window (the period stays with the chunk);
    2. otherwise the last whitespace character, so words are not cut in half;
    3. otherwise a hard cut at ``max_size`` characters.

    Every ``'.'`` is treated as a possible boundary, regardless of context.

    Args:
        text: The text to split. An empty string yields an empty list.
        max_size: Maximum number of characters per chunk; must be positive.

    Returns:
        The chunks in order, each stripped of surrounding whitespace.

    Raises:
        ValueError: If ``max_size`` is not positive (the loop could not advance).
    """
    if not text:
        return []
    if max_size <= 0:
        raise ValueError("max_size must be a positive integer.")

    chunks = []
    start_index = 0
    text_length = len(text)

    while start_index < text_length:
        end_index = start_index + max_size

        # The remainder fits in one window: emit it, stripped like every
        # other chunk, and stop.
        if end_index >= text_length:
            chunks.append(text[start_index:].strip())
            break

        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            next_start = split_pos + 1
            chunk = text[start_index:next_start]
        else:
            # No period in the window: fall back to the last whitespace after
            # the first character so the chunk is never empty.
            space_pos = next(
                (i for i in range(end_index - 1, start_index, -1) if text[i].isspace()),
                -1,
            )
            if space_pos != -1:
                chunk = text[start_index:space_pos]
                next_start = space_pos + 1
            else:
                chunk = text[start_index:end_index]
                next_start = end_index

        chunks.append(chunk.strip())
        start_index = next_start

    return [c for c in chunks if c]
|
|
|
|
|
def do_translation(text_to_translate: str) -> str:
    """Translate one piece of text with the module-level model.

    The text is wrapped as a single user message, rendered through the
    tokenizer's chat template, and passed to ``model.generate`` with sampling
    enabled. Only the newly generated tokens are decoded.

    Args:
        text_to_translate: The text (or full prompt) to send to the model.

    Returns:
        The decoded, stripped completion, or ``""`` for blank input.
    """
    if not text_to_translate.strip():
        return ""

    conversation = [{"role": "user", "content": text_to_translate}]
    rendered_prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered prompt is tokenized with add_special_tokens=False.
    encoded = tokenizer([rendered_prompt], add_special_tokens=False, return_tensors="pt")
    model_inputs = encoded.to(device)

    sampling_options = dict(
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
    )
    generated = model.generate(**model_inputs, **sampling_options)

    # Skip the first prompt_length positions (the input) and keep the rest.
    prompt_length = model_inputs.input_ids.shape[1]
    new_token_ids = generated[0][prompt_length:].tolist()
    decoded = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    return decoded.strip()
|
|
|
|
|
def preprocess_text(text: str) -> str:
    """Flatten newlines into spaces and collapse repeated whitespace.

    Runs of two or more newlines and single newlines each become one space,
    then any run of two or more whitespace characters becomes a single space.
    The result is stripped. Empty or ``None`` input yields ``""``.
    """
    if not text:
        return ""

    without_newlines = re.sub(r'\n{2,}', ' ', text).replace('\n', ' ')
    collapsed = re.sub(r'\s{2,}', ' ', without_newlines)
    return collapsed.strip()
|
|
|
|
|
|
|
|
@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_sentences: int, progress=gr.Progress()) -> str:
    """Translate long text chunk by chunk, carrying context between chunks.

    The input is cleaned with ``preprocess_text`` and, when longer than
    ``chunk_size`` characters, split with ``chunk_text``. Each chunk goes
    through ``do_translation``. From the second chunk on, when
    ``context_sentences`` is non-zero, the prompt is prefixed with the last
    ``context_sentences`` sentences of the most recent non-empty translation.

    Args:
        input_text: Raw text entered by the user.
        chunk_size: Maximum number of characters per chunk.
        context_sentences: How many trailing translated sentences to pass on
            as context; ``0`` disables context.
        progress: Gradio progress tracker used to report status.

    Returns:
        All chunk translations joined with single spaces, or a notice when the
        cleaned input is empty.
    """
    progress(0, desc="Starting...")

    cleaned = preprocess_text(input_text)
    if not cleaned:
        return "Input text is empty. Please enter some text to translate."

    if len(cleaned) > chunk_size:
        pieces = chunk_text(cleaned, chunk_size)
    else:
        pieces = [cleaned]
    total = len(pieces)
    print(f"Processing {total} chunk(s).")

    translations = []
    carried_context = ""

    for number, piece in enumerate(pieces, start=1):
        fraction_done = 0.2 + ((number - 1) / total) * 0.7
        progress(fraction_done, desc=f"Translating chunk {number}/{total}")

        # Context is only added once a previous translation exists and the
        # feature is enabled; otherwise the chunk is sent as-is.
        if carried_context and context_sentences != 0:
            prompt = (
                f"[Previous Translation Context]:\n{carried_context}\n\n"
                f"[New English Text to Translate and Continue]:\n{piece}"
            )
        else:
            prompt = piece

        print(f"--- Prompt for Chunk {number} ---\n{prompt}\n--------------------")

        translated = do_translation(prompt)
        translations.append(translated)
        print(f"Chunk {number} processed successfully.")

        if context_sentences > 0:
            # Split on periods, drop empty fragments, and keep the tail as
            # context for the next chunk. An empty translation leaves the
            # previous context untouched.
            stripped_parts = (part.strip() for part in translated.split('.'))
            sentence_list = [part for part in stripped_parts if part]
            if sentence_list:
                tail = sentence_list[-context_sentences:]
                carried_context = ". ".join(tail) + "."

    combined = " ".join(translations)
    progress(1.0, desc="Done!")
    return combined
|
|
|
|
|
|
|
|
print("\n--- Initializing Gradio Interface ---")

# Input widgets, in the positional order translate_with_chunks expects:
# text, chunk size, context sentence count.
_input_components = [
    gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
    gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Character Chunk Size"),
    gr.Slider(
        minimum=0,
        maximum=5,
        value=2,
        step=1,
        label="Context Overlap (Sentences)",
        info="Number of previous translated (Polish) sentences to provide as context. The most reliable method.",
    ),
]

# Read-only box that displays the joined translation.
_output_component = gr.Textbox(lines=15, label="Model Output", interactive=False)

app = gr.Interface(
    fn=translate_with_chunks,
    inputs=_input_components,
    outputs=_output_component,
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a robust instructional prompt to ensure high-quality, consistent translations.",
    allow_flagging="never",
)
|
|
|
|
|
# Start the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    queued_app = app.queue()
    queued_app.launch()