X-iZhang committed
Commit b5b8cb6 · verified · 1 Parent(s): 842642b

Update app.py

Files changed (1)
  1. app.py +92 -34
app.py CHANGED
@@ -1,11 +1,15 @@
 import os
 # Force CPU-only in this process by hiding CUDA devices (set before importing heavy libs)
-os.environ.setdefault('CUDA_VISIBLE_DEVICES', '')
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
 import torch
 import gradio as gr
 import time
 
+# Force CPU device globally by overriding torch.cuda.is_available
+torch.cuda.is_available = lambda: False
+
 # =========================================
 # Safe Libra Hook (CPU fallback + dtype fix)
 # This hook must run before any heavyweight libra model-loading occurs.
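Note on ordering: `CUDA_VISIBLE_DEVICES` only takes effect if it is set before torch initializes its CUDA state, which is why the assignment sits above `import torch`; the `torch.cuda.is_available = lambda: False` patch then catches any library that probes for CUDA later at runtime. A minimal standalone sketch (not part of the commit) showing the effect:

```python
import os

# Must run before `import torch`: once the CUDA runtime is initialized,
# changing CUDA_VISIBLE_DEVICES no longer affects this process.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import torch

print(torch.cuda.is_available())   # False: no visible devices
print(torch.cuda.device_count())   # 0
print(torch.randn(2, 2).device)    # cpu
```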
@@ -19,14 +23,16 @@ _original_load_pretrained_model = getattr(builder, 'load_pretrained_model', None
 def safe_load_pretrained_model(model_path, model_base=None, model_name=None, **kwargs):
     print("[INFO] Hook activated: safe_load_pretrained_model()")
 
-    # Fill in model_name to avoid calling .lower() on None
+    # Complete model_name to avoid .lower() on None
     if model_name is None:
         model_name = model_path
 
-    # Force CPU arguments when calling the original function, to avoid CUDA initialization
+    # Force CPU parameters when calling original function
     kwargs = dict(kwargs)
-    kwargs.setdefault('device', 'cpu')
-    kwargs.setdefault('device_map', 'cpu')
+    kwargs['device'] = 'cpu'
+    kwargs['device_map'] = 'cpu'
+    kwargs.setdefault('torch_dtype', torch.float32)
+    kwargs.setdefault('low_cpu_mem_usage', True)
 
     if _original_load_pretrained_model is None:
         raise RuntimeError('Original load_pretrained_model not found in builder')
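The switch from `setdefault` to plain assignment is the substantive change here: `setdefault` only fills a key the caller left unset, so a caller passing `device='cuda'` would previously have kept the GPU. A tiny illustration of the difference:

```python
kwargs = {'device': 'cuda'}         # caller explicitly requested GPU

kwargs.setdefault('device', 'cpu')  # no-op: the key already exists
print(kwargs['device'])             # cuda

kwargs['device'] = 'cpu'            # unconditional override, as in the new code
print(kwargs['device'])             # cpu
```

`setdefault` is still used for `torch_dtype` and `low_cpu_mem_usage`, which are soft defaults a caller may legitimately override.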
@@ -50,20 +56,31 @@ def safe_load_pretrained_model(model_path, model_base=None, model_name=None, **k
             # propagate other errors
             raise
 
-    # On CPU, try to upcast the model and vision tower to float32 to reduce compatibility issues
-    if not torch.cuda.is_available():
-        try:
-            model.to(dtype=torch.float32)
-        except Exception as e:
-            print(f"[WARN] Could not upcast LM to float32: {e}")
-        try:
+    # Force all model components to CPU with float32 for compatibility
+    print('[INFO] Forcing all components to CPU with float32 dtype...')
+    try:
+        model = model.to(device='cpu', dtype=torch.float32)
+        print('[INFO] Model moved to CPU (float32).')
+    except Exception as e:
+        print(f"[WARN] Could not move model to cpu/float32: {e}")
+
+    try:
+        if hasattr(model, 'get_vision_tower'):
             vt = model.get_vision_tower()
-            vt.to(device='cpu', dtype=torch.float32)
-            print('[INFO] Vision tower moved to cpu (float32).')
-        except Exception as e:
-            print(f"[WARN] Could not move vision_tower to cpu/float32: {e}")
-    else:
-        print('[INFO] GPU available — keeping original device/dtype behavior.')
+            if vt is not None:
+                vt = vt.to(device='cpu', dtype=torch.float32)
+                print('[INFO] Vision tower moved to CPU (float32).')
+    except Exception as e:
+        print(f"[WARN] Could not move vision_tower to cpu/float32: {e}")
+
+    try:
+        if hasattr(model, 'get_model'):
+            inner_model = model.get_model()
+            if inner_model is not None:
+                inner_model = inner_model.to(device='cpu', dtype=torch.float32)
+                print('[INFO] Inner model moved to CPU (float32).')
+    except Exception as e:
+        print(f"[WARN] Could not move inner model to cpu/float32: {e}")
 
     return tokenizer, model, image_processor, context_len
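Each move above is wrapped in its own try/except so one failing component cannot abort the rest. Note that `nn.Module.to()` mutates parameters in place and returns the module, so the reassignments are belt-and-suspenders. A minimal sketch (a stand-in module, not the Libra model) verifying the end state:

```python
import torch
import torch.nn as nn

# Stand-in for a half-precision checkpoint already on CPU.
module = nn.Linear(4, 4).to(torch.float16)

module = module.to(device='cpu', dtype=torch.float32)

assert all(p.device.type == 'cpu' and p.dtype == torch.float32
           for p in module.parameters())
```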
@@ -80,7 +97,12 @@ def safe_load_model(model_path, model_base=None, model_name=None):
 
 run_libra.load_model = safe_load_model
 
-# Now import CCD and the other hooked symbols (import after the hook so it takes effect)
+# Now import CCD and hook ccd_utils to force CPU for expert models
+import ccd.ccd_utils as ccd_utils_module
+ccd_utils_module._DEVICE = torch.device('cpu')
+print('[INFO] Forced ccd_utils._DEVICE to CPU')
+
+# Now import the evaluation functions
 from ccd import ccd_eval, run_eval
 from libra.eval.run_libra import load_model
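The ordering here is load-bearing: `from ccd import ccd_eval, run_eval` binds names at import time, so the monkey-patches must already be installed when that line runs, and module-level constants like `ccd_utils._DEVICE` must be overwritten before any code reads them. A condensed sketch of the patch-then-import pattern; the `libra.model.builder` path is an assumption, since the builder import itself is outside this diff:

```python
import torch
import libra.model.builder as builder   # assumed path; not shown in this diff

_original = builder.load_pretrained_model

def _cpu_only_load(*args, **kwargs):
    kwargs['device'] = 'cpu'            # force CPU regardless of caller
    return _original(*args, **kwargs)

builder.load_pretrained_model = _cpu_only_load   # 1. install the hook

import ccd.ccd_utils as ccd_utils
ccd_utils._DEVICE = torch.device('cpu')          # 2. override module constant

from ccd import ccd_eval, run_eval               # 3. only now bind the symbols
```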
@@ -88,15 +110,15 @@ from libra.eval.run_libra import load_model
 # Global Configuration
 # =========================================
 MODEL_CATALOGUE = {
+    "Libra-v1.0-3B (⚡Recommended for CPU)": "X-iZhang/libra-v1.0-3b",
     "Libra-v1.0-7B": "X-iZhang/libra-v1.0-7b",
-    "Libra-v1.0-3B": "X-iZhang/libra-v1.0-3b",
     "MAIRA-2": "X-iZhang/libra-maira-2",
     "LLaVA-Med-v1.5": "X-iZhang/libra-llava-med-v1.5-mistral-7b",
     "LLaVA-Rad": "X-iZhang/libra-llava-rad",
     "Med-CXRGen-F": "X-iZhang/Med-CXRGen-F",
     "Med-CXRGen-I": "X-iZhang/Med-CXRGen-I"
 }
-DEFAULT_MODEL_NAME = "MAIRA-2"
+DEFAULT_MODEL_NAME = "Libra-v1.0-3B (⚡Recommended for CPU)"
 _loaded_models = {}
 
 
@@ -104,13 +126,14 @@ _loaded_models = {}
 # Environment Setup
 # =========================================
 def setup_environment():
-    if torch.cuda.is_available():
-        print("🔹 Using GPU:", torch.cuda.get_device_name(0))
-    else:
-        print("🔹 Using CPU")
+    print("🔹 Running in CPU-only mode (forced for Hugging Face Spaces)")
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
     os.environ['TRANSFORMERS_CACHE'] = './cache'
-    torch.set_num_threads(4)
+
+    # Set number of threads for CPU inference
+    num_threads = min(os.cpu_count() or 4, 8)
+    torch.set_num_threads(num_threads)
+    print(f"🔹 Using {num_threads} CPU threads")
 
 
 # =========================================
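On the thread count: `os.cpu_count()` can return `None` in restricted containers (hence the `or 4` fallback), and capping at 8 avoids oversubscription on shared Space CPUs, where extra intra-op threads tend to add contention rather than speed. A sketch of the same knob, with the related inter-op setting added as an extra that is not in the commit:

```python
import os
import torch

num_threads = min(os.cpu_count() or 4, 8)  # fallback, then cap
torch.set_num_threads(num_threads)         # intra-op parallelism (matmul, conv, ...)
torch.set_num_interop_threads(2)           # inter-op parallelism; must run before any parallel work
print(torch.get_num_threads(), torch.get_num_interop_threads())
```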
@@ -125,14 +148,27 @@ def load_or_get_model(model_name: str):
         return _loaded_models[model_path]
 
     print(f"🔹 Loading model: {model_name} ({model_path}) ...")
+    print(f"🔹 This may take 2-5 minutes on CPU, please wait...")
     try:
+        # Clear cache before loading to maximize available memory
+        import gc
+        gc.collect()
+        if hasattr(torch.cuda, 'empty_cache'):
+            torch.cuda.empty_cache()
+
         with torch.no_grad():
             model = load_model(model_path)
+
         _loaded_models[model_path] = model
         print(f"✅ Loaded successfully: {model_name}")
+
+        # Clean up after loading
+        gc.collect()
         return model
     except Exception as e:
         print(f"❌ Error loading model {model_name}: {e}")
+        import traceback
+        traceback.print_exc()
         raise
 
 
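Two details worth noting: `torch.cuda.empty_cache()` is a no-op when CUDA was never initialized, so the guarded call is harmless in this CPU-only process, and `torch.no_grad()` keeps autograd from recording any tensor ops that run while the checkpoint loads. A small sketch of the `no_grad` effect:

```python
import torch

with torch.no_grad():
    w = torch.randn(3, 3, requires_grad=True)
    y = w * 2            # runs without building an autograd graph

print(y.requires_grad)   # False: nothing was recorded inside no_grad
```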
@@ -148,19 +184,25 @@ def generate_ccd_description(
     beta,
     gamma,
     use_run_eval,
-    max_new_tokens
+    max_new_tokens,
+    progress=gr.Progress()
 ):
     """Generate findings using CCD evaluation."""
     if not current_img:
         return "⚠️ Please upload or select an example image first."
 
     try:
+        progress(0, desc="Starting inference...")
         print(f"🔹 Generating description with model: {selected_model_name}")
         print(f"🔹 Parameters: alpha={alpha}, beta={beta}, gamma={gamma}")
         print(f"🔹 Image path: {current_img}")
 
+        progress(0.1, desc="Loading model (this may take several minutes on CPU)...")
         model = load_or_get_model(selected_model_name)
+
+        progress(0.3, desc="Running CCD inference (this may take 5-10 minutes on CPU)...")
         print(f"🔹 Running CCD with {selected_model_name} and expert model {expert_model}...")
+
         ccd_output = ccd_eval(
             libra_model=model,
             image=current_img,
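For readers new to `progress=gr.Progress()`: Gradio inspects the handler's signature and injects a live tracker for that parameter (callers never pass it), and each `progress(fraction, desc=...)` call updates the bar in the UI; this relies on the queue, which the launch block below enables. A minimal self-contained sketch:

```python
import time
import gradio as gr

def slow_task(text, progress=gr.Progress()):
    progress(0, desc="Starting...")
    for _ in progress.tqdm(range(5)):   # progress.tqdm also wraps iterables
        time.sleep(0.2)
    return text.upper()

demo = gr.Interface(fn=slow_task, inputs="text", outputs="text")
# demo.launch()
```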
@@ -172,7 +214,10 @@
             gamma=gamma
         )
 
+        progress(0.8, desc="Processing results...")
+
         if use_run_eval:
+            progress(0.85, desc="Running baseline comparison...")
             baseline_output = run_eval(
                 libra_model=model,
                 image=current_img,
@@ -180,11 +225,13 @@
                 max_new_tokens=max_new_tokens,
                 num_beams=1
             )
+            progress(1.0, desc="Complete!")
             return (
                 f"### 🩺 CCD Result ({expert_model})\n{ccd_output}\n\n"
                 f"---\n### ⚖️ Baseline (run_eval)\n{baseline_output[0]}"
             )
 
+        progress(1.0, desc="Complete!")
         return f"### 🩺 CCD Result ({expert_model})\n{ccd_output}"
 
     except Exception:
@@ -281,10 +328,16 @@ def main():
     ### [Project Page](https://x-izhang.github.io/CCD/) | [Paper](https://arxiv.org/abs/2509.23379) | [Code](https://github.com/X-iZhang/CCD) | [Models](https://huggingface.co/collections/X-iZhang/libra-6772bfccc6079298a0fa5f8d)
 
     **🚨 Performance Warning**
-
-    The demo is currently running on **CPU**, and a single inference takes approximately **500 seconds**.
-    To achieve optimal performance and significantly reduce inference time, a **GPU** is required.
-    For more details, please refer to [launching the demo locally](https://github.com/X-iZhang/CCD#gradio-web-interface).
+
+    This demo is running in **CPU-only** mode. A single inference may take **5-10 minutes** depending on the model and parameters.
+
+    **Recommendations for faster inference:**
+    - Use smaller models (Libra-v1.0-3B is faster than the 7B models)
+    - Reduce `Max New Tokens` to 64-128 (default: 64)
+    - Disable the baseline comparison
+    - For GPU acceleration, please [run the demo locally](https://github.com/X-iZhang/CCD#gradio-web-interface)
+
+    **Note:** If you see "Connection Lost", please wait: the inference is still running and the results will appear when complete.
     """)
 
     with gr.Tab("✨ CCD Demo"):
@@ -347,8 +400,8 @@
                 gamma = gr.Slider(0, 20, value=10, step=1, label="Gamma")
 
                 with gr.Accordion("Advanced Options", open=False):
-                    max_new_tokens = gr.Slider(10, 256, value=128, step=1, label="Max New Tokens")
-                    use_run_eval = gr.Checkbox(label="Compare with baseline (run_eval)", value=False)
+                    max_new_tokens = gr.Slider(10, 256, value=64, step=1, label="Max New Tokens (lower = faster)")
+                    use_run_eval = gr.Checkbox(label="Compare with baseline (run_eval) [doubles inference time]", value=False)
 
                 generate_btn = gr.Button("🚀 Generate", variant="primary")
 
@@ -396,7 +449,12 @@
         pass
 
 
-    demo.launch()
+    # Launch with extended timeout for CPU inference
+    demo.queue(max_size=10)  # Enable queue for better handling of long-running tasks
+    demo.launch(
+        max_threads=4,   # Limit concurrent requests
+        show_error=True  # Show detailed errors
+    )
 
 
 if __name__ == "__main__":
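On the launch changes: a plain HTTP request would be timed out by the Space's proxy long before a multi-minute CPU inference returns, while `demo.queue()` moves submissions onto a persistent event queue that survives long runs; `max_size` bounds how many requests may wait, and `max_threads` caps concurrent workers. A minimal sketch of the same wiring, with a sleep standing in for inference:

```python
import time
import gradio as gr

def slow_infer(text):
    time.sleep(2)          # stand-in for a multi-minute CPU inference
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(slow_infer, inp, out)

demo.queue(max_size=10)    # queued events survive long runtimes
demo.launch(max_threads=4, show_error=True)
```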