X-iZhang commited on
Commit
f9018e6
·
verified ·
1 Parent(s): c2612c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -22
app.py CHANGED
@@ -6,54 +6,67 @@ from ccd import ccd_eval, run_eval
6
  from libra.eval.run_libra import load_model
7
 
8
  # =========================================
9
- # Safe Libra Hook (CPU fallback for no-GPU environments)
10
  # =========================================
 
11
  import libra.model.builder as builder
12
  import libra.eval.run_libra as run_libra
13
 
14
- # --- Patch 1: replace builder.load_pretrained_model ---
15
  _original_load_pretrained_model = builder.load_pretrained_model
16
 
17
def safe_load_pretrained_model(model_path, model_base=None, model_name=None, **kwargs):
    """CPU-fallback wrapper around builder.load_pretrained_model.

    Delegates to the original loader, then picks a device/dtype pair based on
    CUDA availability and best-effort moves the vision tower onto it.
    Returns the original 4-tuple: (tokenizer, model, image_processor, context_len).
    """
    print("[INFO] Hook activated: safe_load_pretrained_model()")

    # Delegate the actual loading to the saved original function.
    tokenizer, model, image_processor, context_len = _original_load_pretrained_model(
        model_path, model_base, model_name, **kwargs
    )

    # Pick the target device/dtype from what the runtime actually offers.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
        print("[INFO] GPU detected — using CUDA + float16.")
    else:
        device, dtype = "cpu", torch.float32
        print("[WARN] No GPU found forcing model to CPU (float32).")

    # Best-effort: relocate the vision tower; a failure here is logged, not fatal.
    try:
        model.get_vision_tower().to(device=device, dtype=dtype)
        print(f"[INFO] Vision tower moved to {device} ({dtype}).")
    except Exception as err:
        print(f"[WARN] Could not move vision tower: {err}")

    return tokenizer, model, image_processor, context_len
42
 
43
-
44
  builder.load_pretrained_model = safe_load_pretrained_model
45
 
46
-
47
- # --- Patch 2: replace run_libra.load_model to force using our patched builder ---
48
def safe_load_model(model_path, model_base=None, model_name=None):
    """Drop-in replacement for run_libra.load_model routed through the safe loader.

    Libra expects model_name to be a valid string, so the model path is used
    as a fallback name when none is given.
    """
    print("[INFO] Hook activated: safe_load_model()")
    effective_name = model_path if model_name is None else model_name
    return safe_load_pretrained_model(model_path, model_base, effective_name)
54
 
55
  run_libra.load_model = safe_load_model
56
- load_model = safe_load_model # rebind for local use in app.py
57
 
58
  # =========================================
59
  # Global Configuration
 
6
  from libra.eval.run_libra import load_model
7
 
8
  # =========================================
9
+ # Safe Libra Hook (CPU fallback + dtype fix)
10
  # =========================================
11
+ import torch
12
  import libra.model.builder as builder
13
  import libra.eval.run_libra as run_libra
14
 
15
+ # 保存原始函数
16
  _original_load_pretrained_model = builder.load_pretrained_model
17
 
18
def safe_load_pretrained_model(model_path, model_base=None, model_name=None, **kwargs):
    """CPU-safe wrapper around builder.load_pretrained_model.

    On CPU-only hosts it forces the underlying loader onto the CPU (bypassing
    CUDA initialization) and upcasts the loaded model to float32 for stable
    inference. On CUDA hosts it leaves the loader's default CUDA/fp16 path
    untouched.

    Returns the original 4-tuple: (tokenizer, model, image_processor, context_len).
    """
    print("[INFO] Hook activated: safe_load_pretrained_model()")

    # Fix 1: fill in model_name to avoid `.lower()` being called on None
    # inside the original loader.
    if model_name is None:
        model_name = model_path

    has_gpu = torch.cuda.is_available()

    # Copy kwargs so a dict reused by the caller is never mutated.
    kwargs = dict(kwargs)

    # Fix 2: only on CPU-only hosts, force the original loader onto the CPU
    # (device_map="cpu" as well, so 'auto' is never wrapped into {"": "auto"}).
    # BUG FIX: previously these defaults were applied unconditionally, which
    # loaded the model on CPU even when a GPU was available — contradicting
    # the "CUDA fp16 path is kept" log below.
    if not has_gpu:
        kwargs.setdefault("device", "cpu")
        kwargs.setdefault("device_map", "cpu")

    # NOTE: the original loader may still set torch_dtype=float16 internally
    # (unless 4/8-bit); we upcast to float32 after it returns.
    tokenizer, model, image_processor, context_len = _original_load_pretrained_model(
        model_path, model_base, model_name, **kwargs
    )

    if not has_gpu:
        # Fix 3: upcast to float32 on CPU for stable execution.
        try:
            # Language-model trunk.
            model.to(dtype=torch.float32)
        except Exception as e:
            print(f"[WARN] Could not upcast LM to float32: {e}")

        try:
            # Vision tower.
            vt = model.get_vision_tower()
            vt.to(device="cpu", dtype=torch.float32)
            print("[INFO] Vision tower moved to cpu (float32).")
        except Exception as e:
            print(f"[WARN] Could not move vision_tower to cpu/float32: {e}")
    else:
        # With a GPU present the default CUDA + fp16 path needs no adjustment.
        print("[INFO] GPU available — default CUDA fp16 path is kept.")

    return tokenizer, model, image_processor, context_len
57
 
58
+ # 将 builder 的加载函数替换为安全版
59
  builder.load_pretrained_model = safe_load_pretrained_model
60
 
61
+ # 同时替换 run_libra.load_model,并把本地名也重绑定,确保后续调用走安全版
 
62
def safe_load_model(model_path, model_base=None, model_name=None):
    """Safe variant of run_libra.load_model; delegates to safe_load_pretrained_model.

    Supplies model_path as the model name when none is provided, since Libra
    requires model_name to be a valid string.
    """
    print("[INFO] Hook activated: safe_load_model()")
    resolved_name = model_path if model_name is None else model_name
    return safe_load_pretrained_model(model_path, model_base, resolved_name)
67
 
68
  run_libra.load_model = safe_load_model
69
+ load_model = safe_load_model # app.py 后续的 load_model() 使用安全版
70
 
71
  # =========================================
72
  # Global Configuration