Commit
·
f40725d
1
Parent(s):
963b208
Add simple colorization fallback using LAB color space when model fails to load
Browse files- app/main_fastai.py +43 -5
- requirements.txt +3 -1
app/main_fastai.py
CHANGED
|
@@ -28,6 +28,8 @@ from PIL import Image
|
|
| 28 |
import torch
|
| 29 |
import uvicorn
|
| 30 |
import gradio as gr
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# FastAI imports
|
| 33 |
from fastai.vision.all import *
|
|
@@ -211,6 +213,40 @@ async def health_check():
|
|
| 211 |
response["model_error"] = model_load_error
|
| 212 |
return response
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
def colorize_pil(image: Image.Image) -> Image.Image:
|
| 215 |
"""Run model prediction and return colorized image"""
|
| 216 |
# Try FastAI first
|
|
@@ -251,7 +287,9 @@ def colorize_pil(image: Image.Image) -> Image.Image:
|
|
| 251 |
return pytorch_colorizer.colorize(image)
|
| 252 |
|
| 253 |
else:
|
| 254 |
-
|
|
|
|
|
|
|
| 255 |
|
| 256 |
@app.post("/colorize")
|
| 257 |
async def colorize_api(
|
|
@@ -262,8 +300,9 @@ async def colorize_api(
|
|
| 262 |
Upload a black & white image -> returns colorized image.
|
| 263 |
Requires Firebase authentication unless DISABLE_AUTH=true
|
| 264 |
"""
|
| 265 |
-
|
| 266 |
-
|
|
|
|
| 267 |
|
| 268 |
if not file.content_type or not file.content_type.startswith("image/"):
|
| 269 |
raise HTTPException(status_code=400, detail="File must be an image")
|
|
@@ -299,8 +338,7 @@ def gradio_colorize(image):
|
|
| 299 |
if image is None:
|
| 300 |
return None
|
| 301 |
try:
|
| 302 |
-
|
| 303 |
-
return None
|
| 304 |
return colorize_pil(image)
|
| 305 |
except Exception as e:
|
| 306 |
logger.error("Gradio colorization error: %s", str(e))
|
|
|
|
| 28 |
import torch
|
| 29 |
import uvicorn
|
| 30 |
import gradio as gr
|
| 31 |
+
import numpy as np
|
| 32 |
+
import cv2
|
| 33 |
|
| 34 |
# FastAI imports
|
| 35 |
from fastai.vision.all import *
|
|
|
|
| 213 |
response["model_error"] = model_load_error
|
| 214 |
return response
|
| 215 |
|
| 216 |
+
def simple_colorize_fallback(image: Image.Image) -> Image.Image:
|
| 217 |
+
"""
|
| 218 |
+
Simple fallback colorization using LAB color space
|
| 219 |
+
This provides basic colorization when the model doesn't load
|
| 220 |
+
"""
|
| 221 |
+
# Convert to LAB color space
|
| 222 |
+
if image.mode != "RGB":
|
| 223 |
+
image = image.convert("RGB")
|
| 224 |
+
|
| 225 |
+
# Convert to numpy array
|
| 226 |
+
img_array = np.array(image)
|
| 227 |
+
|
| 228 |
+
# Convert RGB to LAB
|
| 229 |
+
lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
|
| 230 |
+
|
| 231 |
+
# Apply simple colorization: enhance the L channel and add some color hints
|
| 232 |
+
l, a, b = cv2.split(lab)
|
| 233 |
+
|
| 234 |
+
# Enhance lightness
|
| 235 |
+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
|
| 236 |
+
l = clahe.apply(l)
|
| 237 |
+
|
| 238 |
+
# Add subtle color hints to a and b channels
|
| 239 |
+
# This is a very basic approach - just adds some warm tones
|
| 240 |
+
a = np.clip(a.astype(np.float32) + 5, 0, 255).astype(np.uint8)
|
| 241 |
+
b = np.clip(b.astype(np.float32) + 10, 0, 255).astype(np.uint8)
|
| 242 |
+
|
| 243 |
+
# Merge channels and convert back to RGB
|
| 244 |
+
lab_colored = cv2.merge([l, a, b])
|
| 245 |
+
colored_rgb = cv2.cvtColor(lab_colored, cv2.COLOR_LAB2RGB)
|
| 246 |
+
|
| 247 |
+
return Image.fromarray(colored_rgb)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
def colorize_pil(image: Image.Image) -> Image.Image:
|
| 251 |
"""Run model prediction and return colorized image"""
|
| 252 |
# Try FastAI first
|
|
|
|
| 287 |
return pytorch_colorizer.colorize(image)
|
| 288 |
|
| 289 |
else:
|
| 290 |
+
# Final fallback: simple colorization
|
| 291 |
+
logger.warning("No model loaded, using simple colorization fallback")
|
| 292 |
+
return simple_colorize_fallback(image)
|
| 293 |
|
| 294 |
@app.post("/colorize")
|
| 295 |
async def colorize_api(
|
|
|
|
| 300 |
Upload a black & white image -> returns colorized image.
|
| 301 |
Requires Firebase authentication unless DISABLE_AUTH=true
|
| 302 |
"""
|
| 303 |
+
# Allow fallback colorization even if model isn't loaded
|
| 304 |
+
# if learn is None and pytorch_colorizer is None:
|
| 305 |
+
# raise HTTPException(status_code=503, detail="Colorization model not loaded")
|
| 306 |
|
| 307 |
if not file.content_type or not file.content_type.startswith("image/"):
|
| 308 |
raise HTTPException(status_code=400, detail="File must be an image")
|
|
|
|
| 338 |
if image is None:
|
| 339 |
return None
|
| 340 |
try:
|
| 341 |
+
# Always try to colorize, even with fallback
|
|
|
|
| 342 |
return colorize_pil(image)
|
| 343 |
except Exception as e:
|
| 344 |
logger.error("Gradio colorization error: %s", str(e))
|
requirements.txt
CHANGED
|
@@ -7,4 +7,6 @@ pillow
|
|
| 7 |
firebase-admin
|
| 8 |
fastai
|
| 9 |
huggingface_hub
|
| 10 |
-
pydantic-settings
|
|
|
|
|
|
|
|
|
| 7 |
firebase-admin
|
| 8 |
fastai
|
| 9 |
huggingface_hub
|
| 10 |
+
pydantic-settings
|
| 11 |
+
opencv-python
|
| 12 |
+
numpy
|