Spaces:
Runtime error
Runtime error
Commit
·
9447bbc
1
Parent(s):
1dbd316
fix: add offload folder for Hugging Face dispatch
Browse files- model/load_model.py +10 -11
model/load_model.py
CHANGED
|
@@ -5,34 +5,33 @@ import torch
|
|
| 5 |
from PIL import ImageDraw
|
| 6 |
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
|
| 7 |
|
| 8 |
-
#
|
| 9 |
pretrained_model_id = "google/paligemma2-3b-pt-224"
|
| 10 |
finetuned_model_id = "pyimagesearch/brain-tumor-od-finetuned-paligemma2"
|
| 11 |
token = os.environ.get("HUGGINGFACE_TOKEN")
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
os.makedirs(offload_folder, exist_ok=True)
|
| 16 |
|
| 17 |
-
#
|
| 18 |
processor = PaliGemmaProcessor.from_pretrained(pretrained_model_id, token=token)
|
| 19 |
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
| 20 |
finetuned_model_id,
|
| 21 |
device_map="auto",
|
| 22 |
low_cpu_mem_usage=True,
|
| 23 |
-
offload_folder=
|
| 24 |
token=token
|
| 25 |
)
|
| 26 |
model.eval()
|
| 27 |
|
| 28 |
-
#
|
| 29 |
def clear_memory():
    """Collect Python garbage and, when a GPU exists, flush the CUDA
    allocator and IPC caches so freed tensors return to the device."""
    gc.collect()
    if torch.cuda.is_available():
        for release in (torch.cuda.empty_cache, torch.cuda.ipc_collect):
            release()
|
| 34 |
|
| 35 |
-
#
|
| 36 |
def parse_multiple_locations(decoded_output):
|
| 37 |
loc_pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s+([^;]+)"
|
| 38 |
matches = re.findall(loc_pattern, decoded_output)
|
|
@@ -49,17 +48,17 @@ def parse_multiple_locations(decoded_output):
|
|
| 49 |
})
|
| 50 |
return coords_and_labels
|
| 51 |
|
| 52 |
-
#
|
| 53 |
def draw_boxes(image, coords_and_labels):
|
| 54 |
draw = ImageDraw.Draw(image)
|
| 55 |
width, height = image.size
|
| 56 |
for obj in coords_and_labels:
|
| 57 |
-
y1, x1, y2, x2 = [
|
| 58 |
draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
|
| 59 |
draw.text((x1, y1), obj['label'], fill="red")
|
| 60 |
return image
|
| 61 |
|
| 62 |
-
#
|
| 63 |
def process_image(image, prompt="detect yes"):
|
| 64 |
if not prompt.startswith("<image>"):
|
| 65 |
prompt = "<image>" + prompt
|
|
|
|
| 5 |
from PIL import ImageDraw
|
| 6 |
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
|
| 7 |
|
| 8 |
+
# ==== OFFLOAD FOLDER (Hugging Face dispatch does not create it itself) ====
os.makedirs("./offload", exist_ok=True)

# ==== CONFIGURATION ====
pretrained_model_id = "google/paligemma2-3b-pt-224"
finetuned_model_id = "pyimagesearch/brain-tumor-od-finetuned-paligemma2"
# Hub auth token; None when the env var is unset (public downloads still work).
token = os.environ.get("HUGGINGFACE_TOKEN")
|
|
|
|
| 15 |
|
| 16 |
+
# ==== LOAD PROCESSOR & MODEL ====
# NOTE(review): downloads weights from the Hub at import time; the
# "./offload" directory must already exist (created above), otherwise
# accelerate's dispatch_model fails when spilling layers to disk.
processor = PaliGemmaProcessor.from_pretrained(pretrained_model_id, token=token)

_load_kwargs = dict(
    device_map="auto",            # let accelerate place layers across GPU/CPU/disk
    low_cpu_mem_usage=True,
    offload_folder="./offload",   # required so dispatch_model can offload layers
    token=token,
)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    finetuned_model_id, **_load_kwargs
)
model.eval()  # inference only; disables dropout etc.
|
| 26 |
|
| 27 |
+
# ==== CLEAR MEMORY (if a GPU is present) ====
|
| 28 |
def clear_memory():
    """Release host-side garbage and, if a CUDA device is present,
    return cached GPU memory to the driver between inference calls."""
    gc.collect()
    if not torch.cuda.is_available():
        return  # CPU-only environment: nothing further to release
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
| 33 |
|
| 34 |
+
# ==== PARSE MODEL OUTPUT INTO BOUNDING BOXES ====
|
| 35 |
def parse_multiple_locations(decoded_output):
|
| 36 |
loc_pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s+([^;]+)"
|
| 37 |
matches = re.findall(loc_pattern, decoded_output)
|
|
|
|
| 48 |
})
|
| 49 |
return coords_and_labels
|
| 50 |
|
| 51 |
+
# ==== DRAW BOXES ON THE IMAGE ====
|
| 52 |
def draw_boxes(image, coords_and_labels):
    """Draw every detection's rectangle and label onto *image* and return it.

    Each entry is expected to carry obj['bbox'] as (y1, x1, y2, x2)
    fractions of the image size and obj['label'] as display text —
    presumably produced by parse_multiple_locations; TODO confirm.
    The image is mutated in place and also returned for chaining.
    """
    drawer = ImageDraw.Draw(image)
    width, height = image.size
    for detection in coords_and_labels:
        bbox = detection['bbox']
        # Scale normalized coordinates to pixel space.
        top, bottom = bbox[0] * height, bbox[2] * height
        left, right = bbox[1] * width, bbox[3] * width
        drawer.rectangle([left, top, right, bottom], outline="red", width=3)
        drawer.text((left, top), detection['label'], fill="red")
    return image
|
| 60 |
|
| 61 |
+
# ==== MAIN ENTRY POINT CALLING THE MODEL ====
|
| 62 |
def process_image(image, prompt="detect yes"):
|
| 63 |
if not prompt.startswith("<image>"):
|
| 64 |
prompt = "<image>" + prompt
|