import sys
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from huggingface_hub import login
from dotenv import load_dotenv

# --- FIX: Add project root to Python's path ---
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)

# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Handle both bare @spaces.GPU and parameterized @spaces.GPU(...) usage;
            # the bare form passes the function itself as the single positional argument.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                func = args[0]
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces()

# --- Step 1: Hugging Face Authentication ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")
print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)
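# Note: for local runs the token can live in a .env file next to this script.
# A minimal example entry (the value below is a placeholder, not a real token):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx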

# --- Step 2: Initialize Model and Tokenizer ---
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
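# bfloat16 stores weights in 2 bytes instead of 4, roughly halving memory on
# supported GPUs; float32 is the safe fallback on CPU.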
print(f"--- Using device: {device}, dtype: {dtype} ---")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    print("--- Model and Tokenizer Loaded Successfully ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load components. Error: {e}") from e

# --- Helper Functions ---
def chunk_text(text: str, max_size: int) -> list[str]:
    """Split text into chunks of at most max_size characters, preferring to
    break at the last period inside each window so sentences stay intact."""
    if not text:
        return []
    chunks = []
    start_index = 0
    while start_index < len(text):
        end_index = start_index + max_size
        if end_index >= len(text):
            chunks.append(text[start_index:].strip())
            break
        # Prefer to split at the last sentence boundary within the window.
        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            chunk = text[start_index : split_pos + 1]
            start_index = split_pos + 1
        else:
            chunk = text[start_index:end_index]
            start_index = end_index
        chunks.append(chunk.strip())
    return [c for c in chunks if c]
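
# Illustrative behavior of chunk_text (not executed; values shown for reference):
#   chunk_text("One. Two. Three.", 10)  ->  ["One. Two.", "Three."]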

# --- Step 3: Core Translation Function (Now with Token-by-Token Streaming) ---
@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, temperature: float, top_p: float, top_k: int, progress=gr.Progress()) -> str:
    """
    Processes text by translating each chunk independently and streams the
    results back token-by-token for a smooth, real-time user experience.
    """
    progress(0, desc="Starting...")
    if not input_text:
        yield "Input text is empty. Please enter some text to translate."
        return

    text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text]
    num_chunks = len(text_chunks)
    print(f"Processing {num_chunks} independent chunk(s).")

    full_output = ""
    for i, chunk in enumerate(text_chunks):
        progress(0.1 + (i / num_chunks) * 0.9, desc=f"Translating chunk {i+1}/{num_chunks}")

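        # Each chunk is translated as an independent single-turn chat message,
        # so earlier chunks do not influence the translation of later ones.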
        messages = [{"role": "user", "content": chunk}]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)

        # Use TextIteratorStreamer for real-time token generation
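        # (skip_prompt drops the echoed input prompt from the stream;
        # skip_special_tokens removes markers such as EOS from the decoded text.)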
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Set up generation arguments
        generation_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k
        )

        # Run the generation in a separate thread to avoid blocking the UI
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the accumulated output as each new token arrives
        for new_token in streamer:
            full_output += new_token
            yield full_output
        thread.join()  # ensure generation has fully finished before the next chunk
        
        # Add a space after each chunk for better readability
        full_output += " "
        yield full_output.strip()

    progress(1.0, desc="Done!")


# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(
            minimum=256,
            maximum=2048,
            value=2048,
            step=64,
            label="Character Chunk Size",
            info="Text will be split into chunks of this size for translation."
        ),
        gr.Slider(
            minimum=0.01, # Temperature cannot be 0 for sampling
            maximum=2.0,
            value=0.7,
            step=0.01,
            label="Temperature",
            info="Controls randomness. Higher values mean more random outputs."
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (Nucleus Sampling)",
            info="Selects from tokens with a cumulative probability mass up to this value."
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=50,
            step=1,
            label="Top-k",
            info="Selects from the top 'k' most likely tokens at each step."
        )
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Text Translator (Real-Time Streaming)",
    description="Processes long text by splitting it into independent chunks and streams the translation in real-time.",
    allow_flagging="never"
)

if __name__ == "__main__":
    app.queue().launch()
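
# Local usage sketch (assumes dependencies are installed and HF_TOKEN is set):
#   pip install torch transformers gradio huggingface_hub python-dotenv
#   python app.py  # adjust to this file's actual name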