Nekochu committed on
Commit a6a2e67 · verified · 1 Parent(s): 63873ef

Revert MCP workarounds; unfixable on our side: @gradio-app, please fix the MCP schema ($ref/$defs are nested incorrectly)!
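For context on the schema issue, a minimal sketch (not part of the commit; it assumes the Space is served locally on port 7860 and that the companion gradio_client package is installed) for dumping the endpoint schemas Gradio exposes to API/MCP clients, which is where Gallery-typed inputs surface the $ref/$defs nesting mentioned above:

from gradio_client import Client

# Assumption: the app is running locally; swap in the real Space URL otherwise.
client = Client("http://127.0.0.1:7860/")
# Print every endpoint with its input/output parameter schema.
client.view_api(all_endpoints=True)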

Files changed (1)
  1. app.py +692 -739
app.py CHANGED
@@ -1,739 +1,692 @@
1
- #!/usr/bin/env python3
2
- """SD Dataset Automaker - HF Space (CPU) - Anime character dataset generator for LoRA/fine-tuning."""
3
-
4
- import warnings
5
- warnings.filterwarnings('ignore', category=FutureWarning) # Suppress torch.cuda.amp spam
6
- warnings.filterwarnings('ignore', category=UserWarning, message='.*trust_repo.*')
7
-
8
- import os, re, shutil, zipfile, random, tempfile, argparse, sys
9
- from urllib.parse import quote_plus
10
- from collections import Counter
11
- from pathlib import Path
12
-
13
- from tqdm import tqdm
14
- import gradio as gr
15
- from bs4 import BeautifulSoup
16
- import requests as req_lib
17
- import time
18
- import numpy as np
19
- from PIL import Image
20
- import torch
21
- import torch.nn as nn
22
- from torchvision import models, transforms
23
- from sklearn.metrics.pairwise import pairwise_distances
24
- import onnxruntime as rt
25
- import pandas as pd
26
- import huggingface_hub
27
-
28
- # =============================================================================
29
- # CONFIG
30
- # =============================================================================
31
- EXTS = ('.jpg', '.jpeg', '.png')
32
- MODEL_DIR = Path(__file__).parent.resolve() # Ensure absolute path
33
- YOLO_PATH = MODEL_DIR / "yolov5s_anime.pt"
34
- SIM_PATH = MODEL_DIR / "similarity.pt"
35
- EXAMPLES = [str(MODEL_DIR / f"from_url_spike_spiegel{i}.jpg") for i in range(1, 4)] # absolute paths for gr.Examples
36
- WD_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
37
- TAG_THRESH, BLACKLIST = 0.35, ["bangs", "breasts", "multicolored hair", "gradient hair", "two-tone hair", "virtual youtuber"]
38
- FACE_CONF, FACE_IOU, MIN_FACE, CROP_PAD = 0.5, 0.5, 35, 0.2
39
- SIM_THRESH, BATCH_SZ, FACE_SZ = 32, 16, 224
40
-
41
- # =============================================================================
42
- # UTILS
43
- # =============================================================================
44
- sanitize = lambda s: re.sub(r'[^\w.-]', '', s.replace(" ", "_"))
45
- get_imgs = lambda d: sorted([os.path.join(r,f) for r,_,fs in os.walk(d) for f in fs if f.lower().endswith(EXTS)])
46
- valid_img = lambda p: (lambda i: i.load() or True)(Image.open(p)) if os.path.exists(p) else False
47
-
48
- # HTTP session - mode depends on environment
49
- # CLI (local Windows): cloudscraper bypasses Cloudflare
50
- # HF Spaces: plain requests (cloudscraper fingerprint gets blocked on datacenter IPs)
51
- def init_session(use_cloudscraper=False):
52
- global SESSION, HTTP_CLIENT
53
- if use_cloudscraper:
54
- try:
55
- import cloudscraper
56
- SESSION = cloudscraper.create_scraper()
57
- HTTP_CLIENT = "cloudscraper"
58
- return
59
- except ImportError:
60
- pass # fallback to requests
61
- SESSION = req_lib.Session()
62
- SESSION.headers.update({
63
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
64
- 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8',
65
- 'Accept-Language': 'en-US,en;q=0.5',
66
- 'Accept-Encoding': 'gzip, deflate, br',
67
- 'DNT': '1',
68
- 'Connection': 'keep-alive',
69
- 'Upgrade-Insecure-Requests': '1',
70
- 'Sec-Fetch-Dest': 'document',
71
- 'Sec-Fetch-Mode': 'navigate',
72
- 'Sec-Fetch-Site': 'none',
73
- 'Sec-Fetch-User': '?1',
74
- 'Cache-Control': 'max-age=0',
75
- 'Referer': 'https://fancaps.net/',
76
- })
77
- HTTP_CLIENT = "requests/browser-headers"
78
-
79
- # Default: plain requests (for HF Spaces import)
80
- SESSION = None
81
- HTTP_CLIENT = None
82
- init_session(use_cloudscraper=False)
83
-
84
- # =============================================================================
85
- # SCRAPING
86
- # =============================================================================
87
- def search_fancaps(prompt, anime=True, movies=False, tv=False, log_fn=None):
88
- L = log_fn or print
89
- url = f"https://fancaps.net/search.php?q={quote_plus(prompt)}&submit=Submit"
90
- if anime: url += "&animeCB=Anime"
91
- if movies: url += "&MoviesCB=Movies"
92
- if tv: url += "&TVCB=TV"
93
- L(f" URL: {url}")
94
- try:
95
- resp = SESSION.get(url, timeout=30)
96
- L(f" Status: {resp.status_code}, Size: {len(resp.content)} bytes")
97
- # Log key headers for debugging
98
- cf_ray = resp.headers.get('cf-ray', 'none')
99
- server = resp.headers.get('server', 'unknown')
100
- L(f" Server: {server}, CF-Ray: {cf_ray}")
101
- if resp.status_code != 200:
102
- L(f" ERROR: HTTP {resp.status_code}")
103
- # Show snippet of response to understand the block reason
104
- content_snippet = resp.text[:500].replace('\n', ' ').strip()
105
- L(f" Response snippet: {content_snippet[:200]}...")
106
- return {}
107
- soup = BeautifulSoup(resp.content, "html.parser")
108
- # Debug: check if we got Cloudflare challenge
109
- title = soup.title.string if soup.title else "No title"
110
- L(f" Page title: {title[:50]}")
111
- if "cloudflare" in title.lower() or "challenge" in title.lower():
112
- L(" ERROR: Cloudflare challenge detected!")
113
- return {}
114
- except Exception as e:
115
- L(f" ERROR: {type(e).__name__}: {e}")
116
- return {}
117
- results, cnt = {}, 1
118
- divs = soup.find_all("div", class_="single_post_content")
119
- L(f" Found {len(divs)} content divs")
120
- for div in divs:
121
- if not div.find('h2'): continue
122
- for h2 in div.find_all('h2'):
123
- title = h2.get_text(strip=True).lower()
124
- cat = 'Movies' if 'movie' in title else 'TV' if 'tv' in title else 'Anime' if 'anime' in title else None
125
- if not cat: continue
126
- table = h2.find_next('table')
127
- if not table: continue
128
- results.setdefault(cat, [])
129
- for h4 in table.find_all('h4'):
130
- a = h4.find('a')
131
- if a and a.get('href'):
132
- results[cat].append((a.get_text(strip=True), a['href'], cnt)); cnt += 1
133
- break
134
- L(f" Parsed results: {sum(len(v) for v in results.values())} items in {list(results.keys())}")
135
- return results
136
-
137
- def get_episodes(url, log_fn=None):
138
- L = log_fn or (lambda x: None)
139
- links, page = [], 1
140
- while True:
141
- try:
142
- resp = SESSION.get(f"{url}&page={page}", timeout=20)
143
- L(f" get_episodes page {page}: status={resp.status_code}")
144
- if resp.status_code != 200:
145
- L(f" ERROR: {resp.text[:150]}...")
146
- break
147
- soup = BeautifulSoup(resp.content, "html.parser")
148
- except Exception as e:
149
- L(f" get_episodes ERROR: {type(e).__name__}: {e}")
150
- break
151
- btns = soup.find_all('a', class_='btn btn-block')
152
- if not btns:
153
- L(f" No episode buttons on page {page}")
154
- break
155
- links.extend([("https://fancaps.net" + b['href'] if b['href'].startswith('/') else b['href']) for b in btns if b.get('href')])
156
- L(f" Found {len(btns)} episodes on page {page}, total: {len(links)}")
157
- page += 1
158
- return links or [url]
159
-
160
- def get_frame_names(url, log_fn=None):
161
- L = log_fn or (lambda x: None)
162
- names, page = [], 1
163
- while True:
164
- try:
165
- resp = SESSION.get(f"{url}&page={page}", timeout=20)
166
- if resp.status_code != 200:
167
- L(f" get_frames page {page}: status={resp.status_code}")
168
- break
169
- soup = BeautifulSoup(resp.content, "html.parser")
170
- except Exception as e:
171
- L(f" get_frames ERROR: {type(e).__name__}: {e}")
172
- break
173
- imgs = soup.find_all('img', class_='imageFade')
174
- if not imgs:
175
- if page == 1: L(f" No images found on first page!")
176
- break
177
- names.extend([s.split('/')[-1] for i in imgs if (s := i.get('src')) and s.split('/')[-1] not in names])
178
- pager = soup.select_one('ul.pagination li:last-child a')
179
- if not pager or pager.get('href') in ['#', None]: break
180
- page += 1
181
- L(f" Total frame names: {len(names)}")
182
- return names
183
-
184
- def download(url, folder, name, timeout=10, retries=3):
185
- """Download single image with retry - returns (success, status_code)."""
186
- fp = os.path.join(folder, name)
187
- if os.path.exists(fp): return True, 200
188
- for attempt in range(retries):
189
- try:
190
- r = SESSION.get(url, stream=True, timeout=timeout)
191
- if r.status_code == 200:
192
- with open(fp, 'wb') as f:
193
- for chunk in r.iter_content(16384):
194
- if chunk: f.write(chunk)
195
- return True, 200
196
- if r.status_code == 429: # Rate limit - don't retry immediately
197
- return False, 429
198
- # Other errors - retry
199
- except:
200
- pass
201
- if attempt < retries - 1:
202
- time.sleep(1)
203
- return False, None
204
-
205
- def scrape(name, link, save_dir, max_imgs, progress=None, log_fn=None):
206
- L = log_fn or print
207
- url, folder = "https://fancaps.net" + link, os.path.join(save_dir, sanitize(name))
208
- os.makedirs(folder, exist_ok=True)
209
- section = 'movie' if '/movies/' in link else 'anime' if '/anime/' in link else 'tv'
210
- L(f" [2/8] Scraping: {url}")
211
- L(f" Section: {section}, max: {max_imgs}")
212
- consecutive_429 = 0
213
- max_429 = 3 # Abort after 3 consecutive 429s
214
-
215
- if section == 'movie':
216
- names = get_frame_names(url, log_fn=L)
217
- L(f" Movie frames: {len(names)}")
218
- sampled = random.sample(names, min(max_imgs, len(names))) if names else []
219
- downloaded = 0
220
- for i, n in enumerate(sampled):
221
- if consecutive_429 >= max_429:
222
- L(f" Aborting: {consecutive_429} consecutive 429s")
223
- break
224
- if i > 0: time.sleep(random.uniform(0.3, 0.8)) # Faster delay
225
- try:
226
- if progress and len(sampled) > 0: progress((i+1)/len(sampled), desc=f"Downloading {name[:20]}")
227
- except: pass
228
- success, status = download(f"https://cdni.fancaps.net/file/fancaps-{section}images/{n}", folder, n)
229
- if success:
230
- downloaded += 1
231
- consecutive_429 = 0
232
- elif status == 429:
233
- consecutive_429 += 1
234
- cooldown = 30 * consecutive_429
235
- L(f" 429 rate limit ({consecutive_429}/{max_429}), cooling {cooldown}s...")
236
- time.sleep(cooldown)
237
- else:
238
- consecutive_429 = 0
239
- L(f" Downloaded: {downloaded}/{len(sampled)}")
240
- else:
241
- L(f" Fetching episodes...")
242
- eps = get_episodes(url, log_fn=L)
243
- L(f" Episodes: {len(eps)}")
244
- total = 0
245
- per_ep = max(1, max_imgs // len(eps)) if eps else max_imgs
246
- for i, ep in enumerate(eps):
247
- if total >= max_imgs or consecutive_429 >= max_429: break
248
- names = get_frame_names(ep, log_fn=L)
249
- if not names: continue
250
- ep_dir = os.path.join(folder, f"Ep{i+1}")
251
- os.makedirs(ep_dir, exist_ok=True)
252
- sampled = random.sample(names, min(per_ep, len(names), max_imgs - total))
253
- for j, n in enumerate(sampled):
254
- if consecutive_429 >= max_429: break
255
- if j > 0: time.sleep(random.uniform(0.3, 0.8)) # Faster delay
256
- try:
257
- if progress and max_imgs > 0: progress(total/max_imgs, desc=f"Ep{i+1}")
258
- except: pass # Gradio progress can fail in some contexts
259
- success, status = download(f"https://cdni.fancaps.net/file/fancaps-{section}images/{n}", ep_dir, n)
260
- if success:
261
- total += 1
262
- consecutive_429 = 0
263
- elif status == 429:
264
- consecutive_429 += 1
265
- cooldown = 30 * consecutive_429
266
- L(f" 429 rate limit ({consecutive_429}/{max_429}), cooling {cooldown}s...")
267
- time.sleep(cooldown)
268
- else:
269
- consecutive_429 = 0
270
- L(f" Total downloaded: {total}")
271
-
272
- # =============================================================================
273
- # ML MODELS (cached)
274
- # =============================================================================
275
- _models = {}
276
-
277
- def get_yolo():
278
- if 'yolo' not in _models:
279
- _models['yolo'] = torch.hub.load('ultralytics/yolov5', 'custom', path=str(YOLO_PATH), force_reload=False, verbose=False)
280
- _models['yolo'].conf, _models['yolo'].iou = FACE_CONF, FACE_IOU
281
- return _models['yolo']
282
-
283
- def get_sim():
284
- if 'sim' not in _models:
285
- class SiameseNetwork(nn.Module):
286
- def __init__(self):
287
- super().__init__()
288
- self.base_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
289
- def forward(self, x):
290
- return self.base_model(x) # 1000-class output (trained this way)
291
- m = SiameseNetwork()
292
- # Keep on CPU for consistent distance values across devices
293
- m.load_state_dict(torch.load(str(SIM_PATH), map_location="cpu", weights_only=True))
294
- m.eval()
295
- _models['sim'] = m
296
- return _models['sim']
297
-
298
- def get_tagger():
299
- if 'tag' not in _models:
300
- mp, cp = huggingface_hub.hf_hub_download(WD_REPO, "model.onnx"), huggingface_hub.hf_hub_download(WD_REPO, "selected_tags.csv")
301
- tags = [str(x).replace('_', ' ') for x in pd.read_csv(cp)['name'].tolist()]
302
- sess = rt.InferenceSession(mp, providers=['CPUExecutionProvider'])
303
- _models['tag'] = (sess, tags, sess.get_inputs()[0].shape[1])
304
- return _models['tag']
305
-
306
- # =============================================================================
307
- # PROCESSING
308
- # =============================================================================
309
- def dedup(paths, thresh=0.98):
310
- if not paths: return [], []
311
- m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1); m.fc = nn.Identity(); m.eval()
312
- tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([.485,.456,.406],[.229,.224,.225])])
313
- emb, valid = [], []
314
- with torch.no_grad():
315
- for i in range(0, len(paths), 32):
316
- batch = [(tf(Image.open(p).convert('RGB')), p) for p in paths[i:i+32] if valid_img(p)]
317
- if batch:
318
- x = torch.stack([b[0] for b in batch])
319
- emb.append(m(x).numpy()); valid.extend([b[1] for b in batch])
320
- del m
321
- if not emb: return [], []
322
- emb = np.vstack(emb); emb /= np.linalg.norm(emb, axis=1, keepdims=True).clip(1e-8)
323
- sim = emb @ emb.T; np.fill_diagonal(sim, 0)
324
- keep, drop = [], set()
325
- for i in range(len(valid)):
326
- if i not in drop: keep.append(valid[i]); drop.update(j for j in np.where(sim[i] > thresh)[0] if j > i)
327
- return keep, [valid[i] for i in drop]
328
-
329
- def detect_faces(paths, out_dir):
330
- yolo = get_yolo(); os.makedirs(out_dir, exist_ok=True); cnt = 0
331
- for p in paths:
332
- try:
333
- img = Image.open(p).convert('RGB'); w, h = img.size
334
- for j, det in enumerate(yolo(img, size=640).xyxy[0].cpu().numpy()):
335
- x1, y1, x2, y2, conf, _ = det
336
- bw, bh = x2-x1, y2-y1
337
- x1, y1, x2, y2 = max(0, x1-bw*CROP_PAD), max(0, y1-bh*CROP_PAD), min(w, x2+bw*CROP_PAD), min(h, y2+bh*CROP_PAD)
338
- if min(x2-x1, y2-y1) >= MIN_FACE:
339
- img.crop((int(x1), int(y1), int(x2), int(y2))).save(os.path.join(out_dir, f"{Path(p).stem}-{j+1}-{conf:.2f}.jpg"), quality=95)
340
- cnt += 1
341
- except: pass
342
- return cnt
343
-
344
- def face_emb(paths):
345
- if not paths: return np.array([]), []
346
- m = get_sim() # Always on CPU for consistent distances
347
- def pad(img):
348
- t, w, h = FACE_SZ, img.size[0], img.size[1]; r = w/h
349
- nw, nh = (t, int(t/r)) if r > 1 else (int(t*r), t)
350
- img = img.resize((nw, nh), Image.BICUBIC)
351
- out = Image.new('RGB', (t, t), (0,0,0)); out.paste(img, ((t-nw)//2, (t-nh)//2)); return out
352
- tf = transforms.Compose([lambda x: pad(x.convert('RGB') if x.mode == 'RGBA' else x), transforms.ToTensor()])
353
- emb, valid = [], []
354
- with torch.no_grad():
355
- for i in range(0, len(paths), BATCH_SZ):
356
- batch = [(tf(Image.open(p)), p) for p in paths[i:i+BATCH_SZ] if valid_img(p)]
357
- if batch:
358
- x = torch.stack([b[0] for b in batch]) # CPU tensor
359
- emb.append(m(x).numpy())
360
- valid.extend([b[1] for b in batch])
361
- return (np.vstack(emb), valid) if emb else (np.array([]), [])
362
-
363
- def tag(path, act_tag="", char_tag=""):
364
- sess, tags, sz = get_tagger()
365
- img = Image.open(path).convert('RGB'); w, h = img.size
366
- s = min(sz/w, sz/h); nw, nh = int(w*s), int(h*s)
367
- pad = Image.new('RGB', (sz, sz), (255,255,255)); pad.paste(img.resize((nw, nh), Image.BICUBIC), ((sz-nw)//2, (sz-nh)//2))
368
- probs = sess.run(None, {sess.get_inputs()[0].name: np.expand_dims(np.array(pad).astype(np.float32)[:,:,::-1], 0)})[0][0]
369
- found = [tags[i] for i, p in enumerate(probs) if p > TAG_THRESH and tags[i] not in BLACKLIST]
370
- # Prepend activation tag and character tag if provided
371
- prefix = []
372
- if act_tag: prefix.append(act_tag); found = [t for t in found if t != act_tag]
373
- if char_tag: prefix.append(char_tag.replace("_", " ")); found = [t for t in found if t != char_tag.replace("_", " ")]
374
- return prefix + found
375
-
376
- # =============================================================================
377
- # PIPELINE
378
- # =============================================================================
379
- def parse_direct_url(url):
380
- """Parse direct fancaps URL, extract show name and relative link. Returns (name, link) or (None, None)."""
381
- # Match patterns: showimages.php?ID-Name, MovieImages.php?movieid=ID&name=Name, episodeimages.php?ID-Name
382
- patterns = [
383
- r'fancaps\.net/anime/showimages\.php\?(\d+)-([^&/]+)', # anime show
384
- r'fancaps\.net/tv/showimages\.php\?(\d+)-([^&/]+)', # tv show
385
- r'fancaps\.net/movies/MovieImages\.php\?.*?movieid=(\d+)', # movie
386
- ]
387
- for pat in patterns:
388
- m = re.search(pat, url)
389
- if m:
390
- if 'anime' in url: section = 'anime'
391
- elif 'movies' in url: section = 'movies'
392
- else: section = 'tv'
393
- # Extract name from URL (replace underscores with spaces)
394
- name = m.group(2).replace('_', ' ') if len(m.groups()) > 1 else f"Show_{m.group(1)}"
395
- # Build relative link (what scrape() expects)
396
- if section == 'movies':
397
- link = f"/movies/MovieImages.php?movieid={m.group(1)}"
398
- else:
399
- link = f"/{section}/showimages.php?{m.group(1)}-{m.group(2) if len(m.groups()) > 1 else ''}"
400
- return name, link
401
- return None, None
402
-
403
- def run(query, char, examples, max_img, thresh, act_tag, anime, movies, tv, progress=None, cli_mode=False):
404
- log = []
405
- def L(m): log.append(m); print(m)
406
- def prog(val, desc=""):
407
- if progress and not cli_mode: progress(val, desc=desc)
408
-
409
- work = tempfile.mkdtemp(prefix="ds_")
410
- dirs = {k: os.path.join(work, f"{i}_{k}") for i, k in enumerate(['scrapped','filtered','faces','ex_faces','similar','results'], 1)}
411
- for d in dirs.values(): os.makedirs(d, exist_ok=True)
412
- final_zip = None # Track ZIP for cleanup
413
-
414
- try:
415
- L(f"HTTP client: {HTTP_CLIENT}")
416
- t0 = time.time()
417
-
418
- # Check if query is a direct fancaps URL (bypasses search, works on HF Spaces)
419
- if 'fancaps.net' in query and ('showimages.php' in query or 'MovieImages.php' in query):
420
- L(f"[1/8] Direct URL mode")
421
- name, link = parse_direct_url(query)
422
- if not link:
423
- return None, "\n".join(log) + "\n\nCouldn't parse URL!"
424
- item = (name, link, 1)
425
- L(f" Parsed: {name}")
426
- else:
427
- L(f"[1/8] Search: {query}")
428
- prog(0.05, desc="Searching...")
429
- res = search_fancaps(query, anime, movies, tv, log_fn=L)
430
- if not res:
431
- return None, "\n".join(log) + "\n\nSearch blocked! Use direct fancaps URL."
432
- item = next((items[0] for items in res.values() if items), None)
433
- if not item: return None, "No results!"
434
-
435
- show_name = item[0]
436
- if not char: char = sanitize(show_name)
437
- t1 = time.time(); L(f" Found: {show_name} ({t1-t0:.0f}s)"); prog(0.1, desc="Downloading...")
438
-
439
- # [2/8] Scrape
440
- scrape(item[0], item[1], dirs['scrapped'], max_img, progress if not cli_mode else None, log_fn=L)
441
- imgs = get_imgs(dirs['scrapped'])
442
- t2 = time.time(); L(f"[2/8] Downloaded: {len(imgs)} ({t2-t1:.0f}s)")
443
- if not imgs: return None, "No images downloaded!"
444
-
445
- # [3/8] Dedup
446
- prog(0.3, desc="Dedup...")
447
- imgs = [p for p in imgs if valid_img(p)]
448
- kept, rm = dedup(imgs)
449
- for p in kept: shutil.copy(p, os.path.join(dirs['filtered'], os.path.basename(p)))
450
- t3 = time.time(); L(f"[3/8] Dedup: {len(kept)} kept, -{len(rm)} ({t3-t2:.0f}s)")
451
-
452
- # [4/8] Detect faces
453
- prog(0.4, desc="Faces...")
454
- n = detect_faces(get_imgs(dirs['filtered']), dirs['faces'])
455
- t4 = time.time(); L(f"[4/8] Faces: {n} ({t4-t3:.0f}s)")
456
- if n == 0: return None, "No faces detected!"
457
-
458
- # [5/8] Process examples
459
- prog(0.5, desc="Examples...")
460
- ex_paths = [p for p in (examples or []) if p and os.path.exists(p)]
461
- if not ex_paths: ex_paths = [p for p in EXAMPLES if os.path.exists(p)]
462
- if not ex_paths: return None, "No example images!"
463
- n_ex = detect_faces(ex_paths, dirs['ex_faces'])
464
- t5 = time.time(); L(f"[5/8] Examples: {len(ex_paths)} imgs -> {n_ex} faces ({t5-t4:.0f}s)")
465
- if n_ex == 0: return None, "No faces in examples!"
466
-
467
- # [6/8] Match
468
- prog(0.6, desc="Matching...")
469
- f_emb, f_valid = face_emb(get_imgs(dirs['faces']))
470
- e_emb, _ = face_emb(get_imgs(dirs['ex_faces']))
471
- dists = pairwise_distances(f_emb, e_emb, metric='euclidean').min(axis=1)
472
- similar_idx = np.where(dists < thresh)[0]
473
- similar = [f_valid[i] for i in similar_idx]
474
- similar_dists = dists[similar_idx]
475
- t6 = time.time()
476
- L(f"[6/8] Matches: {len(similar)} (thresh={thresh}) ({t6-t5:.0f}s)")
477
- if len(similar_dists) > 0:
478
- L(f" Distances: min={similar_dists.min():.1f}, max={similar_dists.max():.1f}, mean={similar_dists.mean():.1f}")
479
- if not similar: return None, f"No matches! Try threshold > {thresh}"
480
-
481
- # [7/8] Get originals
482
- prog(0.7, desc="Collect...")
483
- origs = set()
484
- orig_to_dist = {}
485
- for i, fp in enumerate(similar):
486
- parts = os.path.basename(fp).rsplit('-', 2)
487
- base = parts[0] if len(parts) >= 3 else Path(fp).stem
488
- for ext in EXTS:
489
- op = os.path.join(dirs['filtered'], base + ext)
490
- if os.path.exists(op):
491
- origs.add(op)
492
- orig_to_dist[os.path.basename(op)] = similar_dists[i]
493
- break
494
- res_dir = os.path.join(work, f"results_{sanitize(char)}")
495
- os.makedirs(res_dir, exist_ok=True)
496
- for p in origs: shutil.copy(p, os.path.join(res_dir, os.path.basename(p)))
497
- t7 = time.time(); L(f"[7/8] Collected: {len(origs)} ({t7-t6:.0f}s)")
498
-
499
- # [8/8] Tag
500
- prog(0.8, desc="Tagging...")
501
- char_tag = char if char != sanitize(show_name) else ""
502
- for p in get_imgs(res_dir):
503
- tags = tag(p, act_tag, char_tag)
504
- with open(os.path.splitext(p)[0] + ".txt", 'w') as f: f.write(", ".join(tags))
505
- t8 = time.time(); L(f"[8/8] Tagged: {len(origs)} ({t8-t7:.0f}s)")
506
-
507
- # Log each image with distance
508
- L(f"\nResults (distance to ref):")
509
- for name, d in sorted(orig_to_dist.items(), key=lambda x: x[1]):
510
- L(f" {name}: {d:.1f}")
511
-
512
- # Zip
513
- prog(0.95, desc="Zipping...")
514
- zp = os.path.join(work, f"{sanitize(char)}_dataset.zip")
515
- with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as z:
516
- for p in get_imgs(res_dir) + [os.path.splitext(p)[0]+".txt" for p in get_imgs(res_dir)]:
517
- if os.path.exists(p): z.write(p, os.path.basename(p))
518
- # Copy ZIP to persistent temp location (Gradio needs file to exist after return)
519
- final_zip = tempfile.NamedTemporaryFile(delete=False, suffix=".zip", prefix=f"{sanitize(char)}_").name
520
- shutil.copy(zp, final_zip)
521
- L(f"\nDone! {len(origs)} images, total {t8-t0:.0f}s"); prog(1.0, desc="Complete!")
522
- return final_zip, "\n".join(log)
523
- except Exception as e:
524
- import traceback
525
- return None, "\n".join(log) + f"\n\nERROR: {e}\n{traceback.format_exc()}"
526
- finally:
527
- # Clean up work directory (ZIP already copied out)
528
- if os.path.exists(work):
529
- shutil.rmtree(work, ignore_errors=True)
530
-
531
- # =============================================================================
532
- # UI
533
- # =============================================================================
534
- css = """
535
- .gradio-container { padding-top: 10px !important; }
536
- .compact-group { margin-bottom: 8px !important; }
537
- """
538
-
539
- with gr.Blocks(title="SD Dataset Automaker: Fancaps → Face Crop (YOLO) → Similarity (Siamese) → WD Tagger → ZIP") as demo:
540
- gr.Markdown("### SD Dataset Automaker: Fancaps → Face Crop (YOLO) → Similarity (Siamese) → WD Tagger → ZIP")
541
-
542
- with gr.Row():
543
- with gr.Column(scale=3):
544
- # Compact input group
545
- with gr.Group():
546
- with gr.Row():
547
- query = gr.Textbox(
548
- label="Anime / Movie / Fancaps URL",
549
- placeholder="'Cowboy Bebop' or paste URL",
550
- scale=3
551
- )
552
- char = gr.Textbox(label="Character (optional, appends to tags)", placeholder="spike_spiegel", scale=2)
553
-
554
- with gr.Row():
555
- ref_imgs = gr.Gallery(
556
- label="Reference Face Image(s)",
557
- columns=4,
558
- height=100,
559
- interactive=True,
560
- object_fit="scale-down",
561
- scale=3,
562
- )
563
- run_btn = gr.Button("Generate Dataset", variant="primary", size="lg", scale=1)
564
-
565
- # Hidden File input for MCP compatibility (Gallery $ref schema bug persists in Gradio 6.0.1)
566
- ref_files = gr.File(
567
- label="Reference Images (MCP)",
568
- file_count="multiple",
569
- file_types=["image"],
570
- visible=False,
571
- )
572
-
573
- # gr.Examples + gr.Gallery works in Gradio 5.46.0+ (PR #11787)
574
- gr.Examples(
575
- examples=[
576
- ["https://fancaps.net/anime/showimages.php?3092-Cowboy_Bebop", "spike_spiegel", EXAMPLES],
577
- ],
578
- inputs=[query, char, ref_imgs],
579
- label="Example (click to load)",
580
- )
581
-
582
- # Advanced settings in accordion
583
- with gr.Accordion("Advanced Settings", open=False):
584
- with gr.Row():
585
- max_img = gr.Slider(50, 500, 200, step=50, label="Max Downloads (frames to scrape)")
586
- thresh = gr.Slider(20, 60, 32, step=1, label="Face Similarity (lower=stricter)")
587
- with gr.Row():
588
- act_tag = gr.Textbox(label="Trigger Word (prepends to captions)", placeholder="e.g. sks_style", scale=2)
589
- anime_cb = gr.Checkbox(label="Anime", value=True, scale=1)
590
- movies_cb = gr.Checkbox(label="Movies", scale=1)
591
- tv_cb = gr.Checkbox(label="TV", scale=1)
592
-
593
- with gr.Column(scale=1):
594
- out_file = gr.File(label="Download ZIP", interactive=False)
595
- with gr.Accordion("Log", open=True):
596
- out_log = gr.Textbox(label="", lines=12, max_lines=50, show_label=False, autoscroll=False)
597
- gr.Markdown("*CPU: ~5-10 min/run*")
598
-
599
- def process(q, c, imgs, files, mi, th, at, an, mo, tv, prog=gr.Progress()):
600
- if not q:
601
- gr.Warning("Enter anime name or URL")
602
- return None, ""
603
- # Collect paths from Gallery (imgs) or File input (files)
604
- paths = []
605
- for item in (imgs or []):
606
- p = item[0] if isinstance(item, (list, tuple)) else item
607
- if p and os.path.exists(p): paths.append(p)
608
- if not paths and files:
609
- for f in (files if isinstance(files, list) else [files]):
610
- fp = f.name if hasattr(f, 'name') else str(f)
611
- if fp and os.path.exists(fp): paths.append(fp)
612
- if not paths:
613
- gr.Warning("Upload reference images or click 'Load Example'")
614
- return None, ""
615
- if 'fancaps.net' in q:
616
- gr.Info("Direct URL detected")
617
- else:
618
- gr.Info(f"Searching: {q}")
619
- zp, log = run(q, c, paths, mi, th, at, an, mo, tv, prog)
620
- if zp:
621
- gr.Info("Done!")
622
- return zp, log
623
-
624
- run_btn.click(process, [query, char, ref_imgs, ref_files, max_img, thresh, act_tag, anime_cb, movies_cb, tv_cb], [out_file, out_log])
625
-
626
- # MCP-only function (no Gallery - avoids $ref schema bug)
627
- def create_dataset(url: str, character: str, reference_images, max_images: int = 200, similarity_threshold: int = 32, trigger_tag: str = ""):
628
- """Create anime character dataset for LoRA/SD training.
629
-
630
- Args:
631
- url: Fancaps.net URL (e.g. https://fancaps.net/anime/showimages.php?3092-Cowboy_Bebop)
632
- character: Character name for tagging (e.g. spike_spiegel)
633
- reference_images: 1-5 face reference images of target character
634
- max_images: Maximum frames to download (50-500)
635
- similarity_threshold: Face matching threshold, lower=stricter (25-40, default 32)
636
- trigger_tag: Optional trigger word to prepend to captions
637
-
638
- Returns:
639
- Tuple of (zip_file_path, log_text)
640
- """
641
- paths = []
642
- if reference_images:
643
- for f in (reference_images if isinstance(reference_images, list) else [reference_images]):
644
- fp = f.name if hasattr(f, 'name') else str(f)
645
- if fp and os.path.exists(fp): paths.append(fp)
646
- if not paths:
647
- return None, "ERROR: No reference images provided. Upload 1-5 face images."
648
- if not url or 'fancaps.net' not in url:
649
- return None, "ERROR: Invalid URL. Provide a fancaps.net URL."
650
- zp, log = run(url, character, paths, max_images, similarity_threshold, trigger_tag, True, False, False, None)
651
- return zp, log
652
-
653
- # Create MCP interface (File input only - no Gallery)
654
- mcp_interface = gr.Interface(
655
- fn=create_dataset,
656
- inputs=[
657
- gr.Textbox(label="Fancaps URL"),
658
- gr.Textbox(label="Character Name"),
659
- gr.File(label="Reference Images", file_count="multiple", file_types=["image"]),
660
- gr.Slider(50, 500, 200, step=50, label="Max Images"),
661
- gr.Slider(25, 40, 32, step=1, label="Similarity Threshold"),
662
- gr.Textbox(label="Trigger Tag"),
663
- ],
664
- outputs=[gr.File(label="Dataset ZIP"), gr.Textbox(label="Log")],
665
- api_name="create_dataset",
666
- title="SD Dataset Automaker (MCP)",
667
- )
668
-
669
- def run_cli():
670
- """CLI mode with cloudscraper for Cloudflare bypass"""
671
- # Use cloudscraper for CLI (bypasses Cloudflare on local/residential IPs)
672
- init_session(use_cloudscraper=True)
673
-
674
- parser = argparse.ArgumentParser(description="SD Dataset Automaker - Anime character dataset generator")
675
- parser.add_argument("--title", "-t", required=True, help="Anime name or fancaps.net URL")
676
- parser.add_argument("--image", "-i", nargs="+", required=True, help="Reference face images (1-5)")
677
- parser.add_argument("--char", "-c", default="", help="Character name (optional, appends to tags)")
678
- parser.add_argument("--max", "-m", type=int, default=200, help="Max frames to scrape (default: 200)")
679
- parser.add_argument("--thresh", type=float, default=32.0, help="Face similarity threshold, lower=stricter (default: 32)")
680
- parser.add_argument("--tag", default="", help="Trigger word to prepend to captions")
681
- parser.add_argument("--anime", action="store_true", default=True, help="Search anime (default)")
682
- parser.add_argument("--movies", action="store_true", help="Search movies")
683
- parser.add_argument("--tv", action="store_true", help="Search TV")
684
- parser.add_argument("--output", "-o", default=".", help="Output directory (default: current)")
685
- args = parser.parse_args()
686
-
687
- # Validate images
688
- ref_imgs = [p for p in args.image if os.path.exists(p)]
689
- if not ref_imgs:
690
- print(f"ERROR: No valid reference images found: {args.image}")
691
- sys.exit(1)
692
-
693
- print(f"SD Dataset Automaker - CLI Mode")
694
- print(f" Title: {args.title}")
695
- print(f" Refs: {len(ref_imgs)} images")
696
- print(f" Char: {args.char or '(auto from title)'}")
697
- print()
698
-
699
- zp, log = run(
700
- query=args.title,
701
- char=args.char,
702
- examples=ref_imgs,
703
- max_img=args.max,
704
- thresh=args.thresh,
705
- act_tag=args.tag,
706
- anime=args.anime,
707
- movies=args.movies,
708
- tv=args.tv,
709
- cli_mode=True
710
- )
711
-
712
- if zp:
713
- # Copy to output dir
714
- out_path = os.path.join(args.output, os.path.basename(zp))
715
- shutil.copy(zp, out_path)
716
- print(f"\nSaved: {out_path}")
717
- else:
718
- print(f"\nFailed!")
719
- sys.exit(1)
720
-
721
- if __name__ == "__main__":
722
- # CLI mode if args provided, else Gradio UI
723
- if len(sys.argv) > 1:
724
- run_cli()
725
- else:
726
- # Gradio UI mode - combine main UI and MCP interface
727
- allowed_dir = os.path.dirname(os.path.abspath(__file__))
728
- app = gr.TabbedInterface(
729
- [demo, mcp_interface],
730
- ["Dataset Maker", "MCP API"],
731
- title="SD Dataset Automaker"
732
- )
733
- app.launch(
734
- server_name="0.0.0.0",
735
- server_port=7860,
736
- mcp_server=True,
737
- show_error=True,
738
- allowed_paths=[allowed_dir],
739
- )
 
1
+ #!/usr/bin/env python3
2
+ """SD Dataset Automaker - HF Space (CPU) - Anime character dataset generator for LoRA/fine-tuning."""
3
+
4
+ import warnings
5
+ warnings.filterwarnings('ignore', category=FutureWarning) # Suppress torch.cuda.amp spam
6
+ warnings.filterwarnings('ignore', category=UserWarning, message='.*trust_repo.*')
7
+
8
+ import os, re, shutil, zipfile, random, tempfile, argparse, sys
9
+ from urllib.parse import quote_plus
10
+ from collections import Counter
11
+ from pathlib import Path
12
+
13
+ from tqdm import tqdm
14
+ import gradio as gr
15
+ from bs4 import BeautifulSoup
16
+ import requests as req_lib
17
+ import time
18
+ import numpy as np
19
+ from PIL import Image
20
+ import torch
21
+ import torch.nn as nn
22
+ from torchvision import models, transforms
23
+ from sklearn.metrics.pairwise import pairwise_distances
24
+ import onnxruntime as rt
25
+ import pandas as pd
26
+ import huggingface_hub
27
+
28
+ # =============================================================================
29
+ # CONFIG
30
+ # =============================================================================
31
+ EXTS = ('.jpg', '.jpeg', '.png')
32
+ MODEL_DIR = Path(__file__).parent.resolve() # Ensure absolute path
33
+ YOLO_PATH = MODEL_DIR / "yolov5s_anime.pt"
34
+ SIM_PATH = MODEL_DIR / "similarity.pt"
35
+ EXAMPLES = [str(MODEL_DIR / f"from_url_spike_spiegel{i}.jpg") for i in range(1, 4)] # absolute paths for gr.Examples
36
+ WD_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
37
+ TAG_THRESH, BLACKLIST = 0.35, ["bangs", "breasts", "multicolored hair", "gradient hair", "two-tone hair", "virtual youtuber"]
38
+ FACE_CONF, FACE_IOU, MIN_FACE, CROP_PAD = 0.5, 0.5, 35, 0.2
39
+ SIM_THRESH, BATCH_SZ, FACE_SZ = 32, 16, 224
40
+
41
+ # =============================================================================
42
+ # UTILS
43
+ # =============================================================================
44
+ sanitize = lambda s: re.sub(r'[^\w.-]', '', s.replace(" ", "_"))
45
+ get_imgs = lambda d: sorted([os.path.join(r,f) for r,_,fs in os.walk(d) for f in fs if f.lower().endswith(EXTS)])
46
+ valid_img = lambda p: (lambda i: i.load() or True)(Image.open(p)) if os.path.exists(p) else False
47
+
48
+ # HTTP session - mode depends on environment
49
+ # CLI (local Windows): cloudscraper bypasses Cloudflare
50
+ # HF Spaces: plain requests (cloudscraper fingerprint gets blocked on datacenter IPs)
51
+ def init_session(use_cloudscraper=False):
52
+ global SESSION, HTTP_CLIENT
53
+ if use_cloudscraper:
54
+ try:
55
+ import cloudscraper
56
+ SESSION = cloudscraper.create_scraper()
57
+ HTTP_CLIENT = "cloudscraper"
58
+ return
59
+ except ImportError:
60
+ pass # fallback to requests
61
+ SESSION = req_lib.Session()
62
+ SESSION.headers.update({
63
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
64
+ 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8',
65
+ 'Accept-Language': 'en-US,en;q=0.5',
66
+ 'Accept-Encoding': 'gzip, deflate, br',
67
+ 'DNT': '1',
68
+ 'Connection': 'keep-alive',
69
+ 'Upgrade-Insecure-Requests': '1',
70
+ 'Sec-Fetch-Dest': 'document',
71
+ 'Sec-Fetch-Mode': 'navigate',
72
+ 'Sec-Fetch-Site': 'none',
73
+ 'Sec-Fetch-User': '?1',
74
+ 'Cache-Control': 'max-age=0',
75
+ 'Referer': 'https://fancaps.net/',
76
+ })
77
+ HTTP_CLIENT = "requests/browser-headers"
78
+
79
+ # Default: plain requests (for HF Spaces import)
80
+ SESSION = None
81
+ HTTP_CLIENT = None
82
+ init_session(use_cloudscraper=False)
83
+
84
+ # =============================================================================
85
+ # SCRAPING
86
+ # =============================================================================
87
+ def search_fancaps(prompt, anime=True, movies=False, tv=False, log_fn=None):
88
+ L = log_fn or print
89
+ url = f"https://fancaps.net/search.php?q={quote_plus(prompt)}&submit=Submit"
90
+ if anime: url += "&animeCB=Anime"
91
+ if movies: url += "&MoviesCB=Movies"
92
+ if tv: url += "&TVCB=TV"
93
+ L(f" URL: {url}")
94
+ try:
95
+ resp = SESSION.get(url, timeout=30)
96
+ L(f" Status: {resp.status_code}, Size: {len(resp.content)} bytes")
97
+ # Log key headers for debugging
98
+ cf_ray = resp.headers.get('cf-ray', 'none')
99
+ server = resp.headers.get('server', 'unknown')
100
+ L(f" Server: {server}, CF-Ray: {cf_ray}")
101
+ if resp.status_code != 200:
102
+ L(f" ERROR: HTTP {resp.status_code}")
103
+ # Show snippet of response to understand the block reason
104
+ content_snippet = resp.text[:500].replace('\n', ' ').strip()
105
+ L(f" Response snippet: {content_snippet[:200]}...")
106
+ return {}
107
+ soup = BeautifulSoup(resp.content, "html.parser")
108
+ # Debug: check if we got Cloudflare challenge
109
+ title = soup.title.string if soup.title else "No title"
110
+ L(f" Page title: {title[:50]}")
111
+ if "cloudflare" in title.lower() or "challenge" in title.lower():
112
+ L(" ERROR: Cloudflare challenge detected!")
113
+ return {}
114
+ except Exception as e:
115
+ L(f" ERROR: {type(e).__name__}: {e}")
116
+ return {}
117
+ results, cnt = {}, 1
118
+ divs = soup.find_all("div", class_="single_post_content")
119
+ L(f" Found {len(divs)} content divs")
120
+ for div in divs:
121
+ if not div.find('h2'): continue
122
+ for h2 in div.find_all('h2'):
123
+ title = h2.get_text(strip=True).lower()
124
+ cat = 'Movies' if 'movie' in title else 'TV' if 'tv' in title else 'Anime' if 'anime' in title else None
125
+ if not cat: continue
126
+ table = h2.find_next('table')
127
+ if not table: continue
128
+ results.setdefault(cat, [])
129
+ for h4 in table.find_all('h4'):
130
+ a = h4.find('a')
131
+ if a and a.get('href'):
132
+ results[cat].append((a.get_text(strip=True), a['href'], cnt)); cnt += 1
133
+ break
134
+ L(f" Parsed results: {sum(len(v) for v in results.values())} items in {list(results.keys())}")
135
+ return results
136
+
137
+ def get_episodes(url, log_fn=None):
138
+ L = log_fn or (lambda x: None)
139
+ links, page = [], 1
140
+ while True:
141
+ try:
142
+ resp = SESSION.get(f"{url}&page={page}", timeout=20)
143
+ L(f" get_episodes page {page}: status={resp.status_code}")
144
+ if resp.status_code != 200:
145
+ L(f" ERROR: {resp.text[:150]}...")
146
+ break
147
+ soup = BeautifulSoup(resp.content, "html.parser")
148
+ except Exception as e:
149
+ L(f" get_episodes ERROR: {type(e).__name__}: {e}")
150
+ break
151
+ btns = soup.find_all('a', class_='btn btn-block')
152
+ if not btns:
153
+ L(f" No episode buttons on page {page}")
154
+ break
155
+ links.extend([("https://fancaps.net" + b['href'] if b['href'].startswith('/') else b['href']) for b in btns if b.get('href')])
156
+ L(f" Found {len(btns)} episodes on page {page}, total: {len(links)}")
157
+ page += 1
158
+ return links or [url]
159
+
160
+ def get_frame_names(url, log_fn=None):
161
+ L = log_fn or (lambda x: None)
162
+ names, page = [], 1
163
+ while True:
164
+ try:
165
+ resp = SESSION.get(f"{url}&page={page}", timeout=20)
166
+ if resp.status_code != 200:
167
+ L(f" get_frames page {page}: status={resp.status_code}")
168
+ break
169
+ soup = BeautifulSoup(resp.content, "html.parser")
170
+ except Exception as e:
171
+ L(f" get_frames ERROR: {type(e).__name__}: {e}")
172
+ break
173
+ imgs = soup.find_all('img', class_='imageFade')
174
+ if not imgs:
175
+ if page == 1: L(f" No images found on first page!")
176
+ break
177
+ names.extend([s.split('/')[-1] for i in imgs if (s := i.get('src')) and s.split('/')[-1] not in names])
178
+ pager = soup.select_one('ul.pagination li:last-child a')
179
+ if not pager or pager.get('href') in ['#', None]: break
180
+ page += 1
181
+ L(f" Total frame names: {len(names)}")
182
+ return names
183
+
184
+ def download(url, folder, name, timeout=10, retries=3):
185
+ """Download single image with retry - returns (success, status_code)."""
186
+ fp = os.path.join(folder, name)
187
+ if os.path.exists(fp): return True, 200
188
+ for attempt in range(retries):
189
+ try:
190
+ r = SESSION.get(url, stream=True, timeout=timeout)
191
+ if r.status_code == 200:
192
+ with open(fp, 'wb') as f:
193
+ for chunk in r.iter_content(16384):
194
+ if chunk: f.write(chunk)
195
+ return True, 200
196
+ if r.status_code == 429: # Rate limit - don't retry immediately
197
+ return False, 429
198
+ # Other errors - retry
199
+ except:
200
+ pass
201
+ if attempt < retries - 1:
202
+ time.sleep(1)
203
+ return False, None
204
+
205
+ def scrape(name, link, save_dir, max_imgs, progress=None, log_fn=None):
206
+ L = log_fn or print
207
+ url, folder = "https://fancaps.net" + link, os.path.join(save_dir, sanitize(name))
208
+ os.makedirs(folder, exist_ok=True)
209
+ section = 'movie' if '/movies/' in link else 'anime' if '/anime/' in link else 'tv'
210
+ L(f" [2/8] Scraping: {url}")
211
+ L(f" Section: {section}, max: {max_imgs}")
212
+ consecutive_429 = 0
213
+ max_429 = 3 # Abort after 3 consecutive 429s
214
+
215
+ if section == 'movie':
216
+ names = get_frame_names(url, log_fn=L)
217
+ L(f" Movie frames: {len(names)}")
218
+ sampled = random.sample(names, min(max_imgs, len(names))) if names else []
219
+ downloaded = 0
220
+ for i, n in enumerate(sampled):
221
+ if consecutive_429 >= max_429:
222
+ L(f" Aborting: {consecutive_429} consecutive 429s")
223
+ break
224
+ if i > 0: time.sleep(random.uniform(0.3, 0.8)) # Faster delay
225
+ try:
226
+ if progress and len(sampled) > 0: progress((i+1)/len(sampled), desc=f"Downloading {name[:20]}")
227
+ except: pass
228
+ success, status = download(f"https://cdni.fancaps.net/file/fancaps-{section}images/{n}", folder, n)
229
+ if success:
230
+ downloaded += 1
231
+ consecutive_429 = 0
232
+ elif status == 429:
233
+ consecutive_429 += 1
234
+ cooldown = 30 * consecutive_429
235
+ L(f" 429 rate limit ({consecutive_429}/{max_429}), cooling {cooldown}s...")
236
+ time.sleep(cooldown)
237
+ else:
238
+ consecutive_429 = 0
239
+ L(f" Downloaded: {downloaded}/{len(sampled)}")
240
+ else:
241
+ L(f" Fetching episodes...")
242
+ eps = get_episodes(url, log_fn=L)
243
+ L(f" Episodes: {len(eps)}")
244
+ total = 0
245
+ per_ep = max(1, max_imgs // len(eps)) if eps else max_imgs
246
+ for i, ep in enumerate(eps):
247
+ if total >= max_imgs or consecutive_429 >= max_429: break
248
+ names = get_frame_names(ep, log_fn=L)
249
+ if not names: continue
250
+ ep_dir = os.path.join(folder, f"Ep{i+1}")
251
+ os.makedirs(ep_dir, exist_ok=True)
252
+ sampled = random.sample(names, min(per_ep, len(names), max_imgs - total))
253
+ for j, n in enumerate(sampled):
254
+ if consecutive_429 >= max_429: break
255
+ if j > 0: time.sleep(random.uniform(0.3, 0.8)) # Faster delay
256
+ try:
257
+ if progress and max_imgs > 0: progress(total/max_imgs, desc=f"Ep{i+1}")
258
+ except: pass # Gradio progress can fail in some contexts
259
+ success, status = download(f"https://cdni.fancaps.net/file/fancaps-{section}images/{n}", ep_dir, n)
260
+ if success:
261
+ total += 1
262
+ consecutive_429 = 0
263
+ elif status == 429:
264
+ consecutive_429 += 1
265
+ cooldown = 30 * consecutive_429
266
+ L(f" 429 rate limit ({consecutive_429}/{max_429}), cooling {cooldown}s...")
267
+ time.sleep(cooldown)
268
+ else:
269
+ consecutive_429 = 0
270
+ L(f" Total downloaded: {total}")
271
+
272
+ # =============================================================================
273
+ # ML MODELS (cached)
274
+ # =============================================================================
275
+ _models = {}
276
+
277
+ def get_yolo():
278
+ if 'yolo' not in _models:
279
+ _models['yolo'] = torch.hub.load('ultralytics/yolov5', 'custom', path=str(YOLO_PATH), force_reload=False, verbose=False)
280
+ _models['yolo'].conf, _models['yolo'].iou = FACE_CONF, FACE_IOU
281
+ return _models['yolo']
282
+
283
+ def get_sim():
284
+ if 'sim' not in _models:
285
+ class SiameseNetwork(nn.Module):
286
+ def __init__(self):
287
+ super().__init__()
288
+ self.base_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
289
+ def forward(self, x):
290
+ return self.base_model(x) # 1000-class output (trained this way)
291
+ m = SiameseNetwork()
292
+ # Keep on CPU for consistent distance values across devices
293
+ m.load_state_dict(torch.load(str(SIM_PATH), map_location="cpu", weights_only=True))
294
+ m.eval()
295
+ _models['sim'] = m
296
+ return _models['sim']
297
+
298
+ def get_tagger():
299
+ if 'tag' not in _models:
300
+ mp, cp = huggingface_hub.hf_hub_download(WD_REPO, "model.onnx"), huggingface_hub.hf_hub_download(WD_REPO, "selected_tags.csv")
301
+ tags = [str(x).replace('_', ' ') for x in pd.read_csv(cp)['name'].tolist()]
302
+ sess = rt.InferenceSession(mp, providers=['CPUExecutionProvider'])
303
+ _models['tag'] = (sess, tags, sess.get_inputs()[0].shape[1])
304
+ return _models['tag']
305
+
306
+ # =============================================================================
307
+ # PROCESSING
308
+ # =============================================================================
309
+ def dedup(paths, thresh=0.98):
310
+ if not paths: return [], []
311
+ m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1); m.fc = nn.Identity(); m.eval()
312
+ tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([.485,.456,.406],[.229,.224,.225])])
313
+ emb, valid = [], []
314
+ with torch.no_grad():
315
+ for i in range(0, len(paths), 32):
316
+ batch = [(tf(Image.open(p).convert('RGB')), p) for p in paths[i:i+32] if valid_img(p)]
317
+ if batch:
318
+ x = torch.stack([b[0] for b in batch])
319
+ emb.append(m(x).numpy()); valid.extend([b[1] for b in batch])
320
+ del m
321
+ if not emb: return [], []
322
+ emb = np.vstack(emb); emb /= np.linalg.norm(emb, axis=1, keepdims=True).clip(1e-8)
323
+ sim = emb @ emb.T; np.fill_diagonal(sim, 0)
324
+ keep, drop = [], set()
325
+ for i in range(len(valid)):
326
+ if i not in drop: keep.append(valid[i]); drop.update(j for j in np.where(sim[i] > thresh)[0] if j > i)
327
+ return keep, [valid[i] for i in drop]
328
+
329
+ def detect_faces(paths, out_dir):
330
+ yolo = get_yolo(); os.makedirs(out_dir, exist_ok=True); cnt = 0
331
+ for p in paths:
332
+ try:
333
+ img = Image.open(p).convert('RGB'); w, h = img.size
334
+ for j, det in enumerate(yolo(img, size=640).xyxy[0].cpu().numpy()):
335
+ x1, y1, x2, y2, conf, _ = det
336
+ bw, bh = x2-x1, y2-y1
337
+ x1, y1, x2, y2 = max(0, x1-bw*CROP_PAD), max(0, y1-bh*CROP_PAD), min(w, x2+bw*CROP_PAD), min(h, y2+bh*CROP_PAD)
338
+ if min(x2-x1, y2-y1) >= MIN_FACE:
339
+ img.crop((int(x1), int(y1), int(x2), int(y2))).save(os.path.join(out_dir, f"{Path(p).stem}-{j+1}-{conf:.2f}.jpg"), quality=95)
340
+ cnt += 1
341
+ except: pass
342
+ return cnt
343
+
344
+ def face_emb(paths):
345
+ if not paths: return np.array([]), []
346
+ m = get_sim() # Always on CPU for consistent distances
347
+ def pad(img):
348
+ t, w, h = FACE_SZ, img.size[0], img.size[1]; r = w/h
349
+ nw, nh = (t, int(t/r)) if r > 1 else (int(t*r), t)
350
+ img = img.resize((nw, nh), Image.BICUBIC)
351
+ out = Image.new('RGB', (t, t), (0,0,0)); out.paste(img, ((t-nw)//2, (t-nh)//2)); return out
352
+ tf = transforms.Compose([lambda x: pad(x.convert('RGB') if x.mode == 'RGBA' else x), transforms.ToTensor()])
353
+ emb, valid = [], []
354
+ with torch.no_grad():
355
+ for i in range(0, len(paths), BATCH_SZ):
356
+ batch = [(tf(Image.open(p)), p) for p in paths[i:i+BATCH_SZ] if valid_img(p)]
357
+ if batch:
358
+ x = torch.stack([b[0] for b in batch]) # CPU tensor
359
+ emb.append(m(x).numpy())
360
+ valid.extend([b[1] for b in batch])
361
+ return (np.vstack(emb), valid) if emb else (np.array([]), [])
362
+
363
+ def tag(path, act_tag="", char_tag=""):
364
+ sess, tags, sz = get_tagger()
365
+ img = Image.open(path).convert('RGB'); w, h = img.size
366
+ s = min(sz/w, sz/h); nw, nh = int(w*s), int(h*s)
367
+ pad = Image.new('RGB', (sz, sz), (255,255,255)); pad.paste(img.resize((nw, nh), Image.BICUBIC), ((sz-nw)//2, (sz-nh)//2))
368
+ probs = sess.run(None, {sess.get_inputs()[0].name: np.expand_dims(np.array(pad).astype(np.float32)[:,:,::-1], 0)})[0][0]
369
+ found = [tags[i] for i, p in enumerate(probs) if p > TAG_THRESH and tags[i] not in BLACKLIST]
370
+ # Prepend activation tag and character tag if provided
371
+ prefix = []
372
+ if act_tag: prefix.append(act_tag); found = [t for t in found if t != act_tag]
373
+ if char_tag: prefix.append(char_tag.replace("_", " ")); found = [t for t in found if t != char_tag.replace("_", " ")]
374
+ return prefix + found
375
+
376
+ # =============================================================================
377
+ # PIPELINE
378
+ # =============================================================================
379
+ def parse_direct_url(url):
380
+ """Parse direct fancaps URL, extract show name and relative link. Returns (name, link) or (None, None)."""
381
+ # Match patterns: showimages.php?ID-Name, MovieImages.php?movieid=ID&name=Name, episodeimages.php?ID-Name
382
+ patterns = [
383
+ r'fancaps\.net/anime/showimages\.php\?(\d+)-([^&/]+)', # anime show
384
+ r'fancaps\.net/tv/showimages\.php\?(\d+)-([^&/]+)', # tv show
385
+ r'fancaps\.net/movies/MovieImages\.php\?.*?movieid=(\d+)', # movie
386
+ ]
387
+ for pat in patterns:
388
+ m = re.search(pat, url)
389
+ if m:
390
+ if 'anime' in url: section = 'anime'
391
+ elif 'movies' in url: section = 'movies'
392
+ else: section = 'tv'
393
+ # Extract name from URL (replace underscores with spaces)
394
+ name = m.group(2).replace('_', ' ') if len(m.groups()) > 1 else f"Show_{m.group(1)}"
395
+ # Build relative link (what scrape() expects)
396
+ if section == 'movies':
397
+ link = f"/movies/MovieImages.php?movieid={m.group(1)}"
398
+ else:
399
+ link = f"/{section}/showimages.php?{m.group(1)}-{m.group(2) if len(m.groups()) > 1 else ''}"
400
+ return name, link
401
+ return None, None
402
+
403
+ def run(query, char, examples, max_img, thresh, act_tag, anime, movies, tv, progress=None, cli_mode=False):
404
+ log = []
405
+ def L(m): log.append(m); print(m)
406
+ def prog(val, desc=""):
407
+ if progress and not cli_mode: progress(val, desc=desc)
408
+
409
+ work = tempfile.mkdtemp(prefix="ds_")
410
+ dirs = {k: os.path.join(work, f"{i}_{k}") for i, k in enumerate(['scrapped','filtered','faces','ex_faces','similar','results'], 1)}
411
+ for d in dirs.values(): os.makedirs(d, exist_ok=True)
412
+ final_zip = None # Track ZIP for cleanup
413
+
414
+ try:
415
+ L(f"HTTP client: {HTTP_CLIENT}")
416
+ t0 = time.time()
417
+
418
+ # Check if query is a direct fancaps URL (bypasses search, works on HF Spaces)
419
+ if 'fancaps.net' in query and ('showimages.php' in query or 'MovieImages.php' in query):
420
+ L(f"[1/8] Direct URL mode")
421
+ name, link = parse_direct_url(query)
422
+ if not link:
423
+ return None, "\n".join(log) + "\n\nCouldn't parse URL!"
424
+ item = (name, link, 1)
425
+ L(f" Parsed: {name}")
426
+ else:
427
+ L(f"[1/8] Search: {query}")
428
+ prog(0.05, desc="Searching...")
429
+ res = search_fancaps(query, anime, movies, tv, log_fn=L)
430
+ if not res:
431
+ return None, "\n".join(log) + "\n\nSearch blocked! Use direct fancaps URL."
432
+ item = next((items[0] for items in res.values() if items), None)
433
+ if not item: return None, "No results!"
434
+
435
+ show_name = item[0]
436
+ if not char: char = sanitize(show_name)
437
+ t1 = time.time(); L(f" Found: {show_name} ({t1-t0:.0f}s)"); prog(0.1, desc="Downloading...")
438
+
439
+ # [2/8] Scrape
440
+ scrape(item[0], item[1], dirs['scrapped'], max_img, progress if not cli_mode else None, log_fn=L)
441
+ imgs = get_imgs(dirs['scrapped'])
442
+ t2 = time.time(); L(f"[2/8] Downloaded: {len(imgs)} ({t2-t1:.0f}s)")
443
+ if not imgs: return None, "No images downloaded!"
444
+
445
+ # [3/8] Dedup
446
+ prog(0.3, desc="Dedup...")
447
+ imgs = [p for p in imgs if valid_img(p)]
448
+ kept, rm = dedup(imgs)
449
+ for p in kept: shutil.copy(p, os.path.join(dirs['filtered'], os.path.basename(p)))
450
+ t3 = time.time(); L(f"[3/8] Dedup: {len(kept)} kept, -{len(rm)} ({t3-t2:.0f}s)")
451
+
452
+ # [4/8] Detect faces
453
+ prog(0.4, desc="Faces...")
454
+ n = detect_faces(get_imgs(dirs['filtered']), dirs['faces'])
455
+ t4 = time.time(); L(f"[4/8] Faces: {n} ({t4-t3:.0f}s)")
456
+ if n == 0: return None, "No faces detected!"
457
+
458
+ # [5/8] Process examples
459
+ prog(0.5, desc="Examples...")
460
+ ex_paths = [p for p in (examples or []) if p and os.path.exists(p)]
461
+ if not ex_paths: ex_paths = [p for p in EXAMPLES if os.path.exists(p)]
462
+ if not ex_paths: return None, "No example images!"
463
+ n_ex = detect_faces(ex_paths, dirs['ex_faces'])
464
+ t5 = time.time(); L(f"[5/8] Examples: {len(ex_paths)} imgs -> {n_ex} faces ({t5-t4:.0f}s)")
465
+ if n_ex == 0: return None, "No faces in examples!"
466
+
467
+ # [6/8] Match
468
+ prog(0.6, desc="Matching...")
469
+ f_emb, f_valid = face_emb(get_imgs(dirs['faces']))
470
+ e_emb, _ = face_emb(get_imgs(dirs['ex_faces']))
471
+ dists = pairwise_distances(f_emb, e_emb, metric='euclidean').min(axis=1)
472
+ similar_idx = np.where(dists < thresh)[0]
473
+ similar = [f_valid[i] for i in similar_idx]
474
+ similar_dists = dists[similar_idx]
475
+ t6 = time.time()
476
+ L(f"[6/8] Matches: {len(similar)} (thresh={thresh}) ({t6-t5:.0f}s)")
477
+ if len(similar_dists) > 0:
478
+ L(f" Distances: min={similar_dists.min():.1f}, max={similar_dists.max():.1f}, mean={similar_dists.mean():.1f}")
479
+ if not similar: return None, f"No matches! Try threshold > {thresh}"
480
+
481
+ # [7/8] Get originals
482
+ prog(0.7, desc="Collect...")
483
+ origs = set()
484
+ orig_to_dist = {}
485
+ for i, fp in enumerate(similar):
486
+ parts = os.path.basename(fp).rsplit('-', 2)
487
+ base = parts[0] if len(parts) >= 3 else Path(fp).stem
488
+ for ext in EXTS:
489
+ op = os.path.join(dirs['filtered'], base + ext)
490
+ if os.path.exists(op):
491
+ origs.add(op)
492
+ orig_to_dist[os.path.basename(op)] = similar_dists[i]
493
+ break
494
+ res_dir = os.path.join(work, f"results_{sanitize(char)}")
495
+ os.makedirs(res_dir, exist_ok=True)
496
+ for p in origs: shutil.copy(p, os.path.join(res_dir, os.path.basename(p)))
497
+ t7 = time.time(); L(f"[7/8] Collected: {len(origs)} ({t7-t6:.0f}s)")
498
+
499
+ # [8/8] Tag
500
+ prog(0.8, desc="Tagging...")
501
+ char_tag = char if char != sanitize(show_name) else ""
502
+ for p in get_imgs(res_dir):
503
+ tags = tag(p, act_tag, char_tag)
504
+ with open(os.path.splitext(p)[0] + ".txt", 'w') as f: f.write(", ".join(tags))
505
+ t8 = time.time(); L(f"[8/8] Tagged: {len(origs)} ({t8-t7:.0f}s)")
506
+
507
+ # Log each image with distance
508
+ L(f"\nResults (distance to ref):")
509
+ for name, d in sorted(orig_to_dist.items(), key=lambda x: x[1]):
510
+ L(f" {name}: {d:.1f}")
511
+
512
+ # Zip
513
+ prog(0.95, desc="Zipping...")
514
+ zp = os.path.join(work, f"{sanitize(char)}_dataset.zip")
515
+ with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as z:
516
+ for p in get_imgs(res_dir) + [os.path.splitext(p)[0]+".txt" for p in get_imgs(res_dir)]:
517
+ if os.path.exists(p): z.write(p, os.path.basename(p))
518
+ # Copy ZIP to persistent temp location (Gradio needs file to exist after return)
519
+ final_zip = tempfile.NamedTemporaryFile(delete=False, suffix=".zip", prefix=f"{sanitize(char)}_").name
520
+ shutil.copy(zp, final_zip)
521
+ L(f"\nDone! {len(origs)} images, total {t8-t0:.0f}s"); prog(1.0, desc="Complete!")
522
+ return final_zip, "\n".join(log)
523
+ except Exception as e:
524
+ import traceback
525
+ return None, "\n".join(log) + f"\n\nERROR: {e}\n{traceback.format_exc()}"
526
+ finally:
527
+ # Clean up work directory (ZIP already copied out)
528
+ if os.path.exists(work):
529
+ shutil.rmtree(work, ignore_errors=True)
530
+
531
+ # =============================================================================
532
+ # UI
533
+ # =============================================================================
534
+ css = """
535
+ .gradio-container { padding-top: 10px !important; }
536
+ .compact-group { margin-bottom: 8px !important; }
537
+ """
538
+
539
+ with gr.Blocks(title="SD Dataset Automaker: Fancaps → Face Crop (YOLO) → Similarity (Siamese) → WD Tagger → ZIP") as demo:
540
+ gr.Markdown("### SD Dataset Automaker: Fancaps → Face Crop (YOLO) → Similarity (Siamese) → WD Tagger → ZIP")
541
+
542
+ with gr.Row():
543
+ with gr.Column(scale=3):
544
+ # Compact input group
545
+ with gr.Group():
546
+ with gr.Row():
547
+ query = gr.Textbox(
548
+ label="Anime / Movie / Fancaps URL",
549
+ placeholder="'Cowboy Bebop' or paste URL",
550
+ scale=3
551
+ )
552
+ char = gr.Textbox(label="Character (optional, appends to tags)", placeholder="spike_spiegel", scale=2)
553
+
554
+ with gr.Row():
555
+ ref_imgs = gr.Gallery(
556
+ label="Reference Face Image(s)",
557
+ columns=4,
558
+ height=100,
559
+ interactive=True,
560
+ object_fit="scale-down",
561
+ scale=3,
562
+ )
563
+ run_btn = gr.Button("Generate Dataset", variant="primary", size="lg", scale=1)
564
+
565
+ # Hidden File input for MCP compatibility (Gallery $ref schema bug persists in Gradio 6.0.1)
566
+ ref_files = gr.File(
567
+ label="Reference Images (MCP)",
568
+ file_count="multiple",
569
+ file_types=["image"],
570
+ visible=False,
571
+ )
572
+
573
+ # gr.Examples + gr.Gallery works in Gradio 5.46.0+ (PR #11787)
574
+ gr.Examples(
575
+ examples=[
576
+ ["https://fancaps.net/anime/showimages.php?3092-Cowboy_Bebop", "spike_spiegel", EXAMPLES],
577
+ ],
578
+ inputs=[query, char, ref_imgs],
579
+ label="Example (click to load)",
580
+ )
581
+
582
+ # Advanced settings in accordion
583
+ with gr.Accordion("Advanced Settings", open=False):
584
+ with gr.Row():
585
+ max_img = gr.Slider(50, 500, 200, step=50, label="Max Downloads (frames to scrape)")
586
+ thresh = gr.Slider(20, 60, 32, step=1, label="Face Similarity (lower=stricter)")
587
+ with gr.Row():
588
+ act_tag = gr.Textbox(label="Trigger Word (prepends to captions)", placeholder="e.g. sks_style", scale=2)
589
+ anime_cb = gr.Checkbox(label="Anime", value=True, scale=1)
590
+ movies_cb = gr.Checkbox(label="Movies", scale=1)
591
+ tv_cb = gr.Checkbox(label="TV", scale=1)
592
+
593
+ with gr.Column(scale=1):
594
+ out_file = gr.File(label="Download ZIP", interactive=False)
595
+ with gr.Accordion("Log", open=True):
596
+ out_log = gr.Textbox(label="", lines=12, max_lines=50, show_label=False, autoscroll=False)
597
+ gr.Markdown("*CPU: ~5-10 min/run*")
598
+
599
+ def process(q, c, imgs, files, mi, th, at, an, mo, tv, prog=gr.Progress()):
600
+ if not q:
601
+ gr.Warning("Enter anime name or URL")
602
+ return None, ""
603
+ # Collect paths from Gallery (imgs) or File input (files)
604
+ paths = []
605
+ for item in (imgs or []):
606
+ p = item[0] if isinstance(item, (list, tuple)) else item
607
+ if p and os.path.exists(p): paths.append(p)
608
+ if not paths and files:
609
+ for f in (files if isinstance(files, list) else [files]):
610
+ fp = f.name if hasattr(f, 'name') else str(f)
611
+ if fp and os.path.exists(fp): paths.append(fp)
612
+ if not paths:
613
+ gr.Warning("Upload reference images or click 'Load Example'")
614
+ return None, ""
615
+ if 'fancaps.net' in q:
616
+ gr.Info("Direct URL detected")
617
+ else:
618
+ gr.Info(f"Searching: {q}")
619
+ zp, log = run(q, c, paths, mi, th, at, an, mo, tv, prog)
620
+ if zp:
621
+ gr.Info("Done!")
622
+ return zp, log
623
+
624
+ run_btn.click(process, [query, char, ref_imgs, ref_files, max_img, thresh, act_tag, anime_cb, movies_cb, tv_cb], [out_file, out_log])
625
+
626
+ def run_cli():
627
+ """CLI mode with cloudscraper for Cloudflare bypass"""
628
+ # Use cloudscraper for CLI (bypasses Cloudflare on local/residential IPs)
629
+ init_session(use_cloudscraper=True)
630
+
631
+ parser = argparse.ArgumentParser(description="SD Dataset Automaker - Anime character dataset generator")
632
+ parser.add_argument("--title", "-t", required=True, help="Anime name or fancaps.net URL")
633
+ parser.add_argument("--image", "-i", nargs="+", required=True, help="Reference face images (1-5)")
634
+ parser.add_argument("--char", "-c", default="", help="Character name (optional, appends to tags)")
635
+ parser.add_argument("--max", "-m", type=int, default=200, help="Max frames to scrape (default: 200)")
636
+ parser.add_argument("--thresh", type=float, default=32.0, help="Face similarity threshold, lower=stricter (default: 32)")
637
+ parser.add_argument("--tag", default="", help="Trigger word to prepend to captions")
638
+ parser.add_argument("--anime", action="store_true", default=True, help="Search anime (default)")
639
+ parser.add_argument("--movies", action="store_true", help="Search movies")
640
+ parser.add_argument("--tv", action="store_true", help="Search TV")
641
+ parser.add_argument("--output", "-o", default=".", help="Output directory (default: current)")
642
+ args = parser.parse_args()
643
+
644
+ # Validate images
645
+ ref_imgs = [p for p in args.image if os.path.exists(p)]
646
+ if not ref_imgs:
647
+ print(f"ERROR: No valid reference images found: {args.image}")
648
+ sys.exit(1)
649
+
650
+ print(f"SD Dataset Automaker - CLI Mode")
651
+ print(f" Title: {args.title}")
652
+ print(f" Refs: {len(ref_imgs)} images")
653
+ print(f" Char: {args.char or '(auto from title)'}")
654
+ print()
655
+
656
+ zp, log = run(
657
+ query=args.title,
658
+ char=args.char,
659
+ examples=ref_imgs,
660
+ max_img=args.max,
661
+ thresh=args.thresh,
662
+ act_tag=args.tag,
663
+ anime=args.anime,
664
+ movies=args.movies,
665
+ tv=args.tv,
666
+ cli_mode=True
667
+ )
668
+
669
+ if zp:
670
+ # Copy to output dir
671
+ out_path = os.path.join(args.output, os.path.basename(zp))
672
+ shutil.copy(zp, out_path)
673
+ print(f"\nSaved: {out_path}")
674
+ else:
675
+ print(f"\nFailed!")
676
+ sys.exit(1)
677
+
678
+ if __name__ == "__main__":
679
+ # CLI mode if args provided, else Gradio UI
680
+ if len(sys.argv) > 1:
681
+ run_cli()
682
+ else:
683
+ # Gradio UI mode
684
+ allowed_dir = os.path.dirname(os.path.abspath(__file__))
685
+ demo.launch(
686
+ server_name="0.0.0.0",
687
+ server_port=7860,
688
+ mcp_server=True,
689
+ show_error=True,
690
+ allowed_paths=[allowed_dir],
691
+ css=css,
692
+ )