LogicGoInfotechSpaces commited on
Commit
f40725d
·
1 Parent(s): 963b208

Add simple colorization fallback using LAB color space when model fails to load

Browse files
Files changed (2) hide show
  1. app/main_fastai.py +43 -5
  2. requirements.txt +3 -1
app/main_fastai.py CHANGED
@@ -28,6 +28,8 @@ from PIL import Image
28
  import torch
29
  import uvicorn
30
  import gradio as gr
 
 
31
 
32
  # FastAI imports
33
  from fastai.vision.all import *
@@ -211,6 +213,40 @@ async def health_check():
211
  response["model_error"] = model_load_error
212
  return response
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  def colorize_pil(image: Image.Image) -> Image.Image:
215
  """Run model prediction and return colorized image"""
216
  # Try FastAI first
@@ -251,7 +287,9 @@ def colorize_pil(image: Image.Image) -> Image.Image:
251
  return pytorch_colorizer.colorize(image)
252
 
253
  else:
254
- raise RuntimeError("No colorization model loaded")
 
 
255
 
256
  @app.post("/colorize")
257
  async def colorize_api(
@@ -262,8 +300,9 @@ async def colorize_api(
262
  Upload a black & white image -> returns colorized image.
263
  Requires Firebase authentication unless DISABLE_AUTH=true
264
  """
265
- if learn is None and pytorch_colorizer is None:
266
- raise HTTPException(status_code=503, detail="Colorization model not loaded")
 
267
 
268
  if not file.content_type or not file.content_type.startswith("image/"):
269
  raise HTTPException(status_code=400, detail="File must be an image")
@@ -299,8 +338,7 @@ def gradio_colorize(image):
299
  if image is None:
300
  return None
301
  try:
302
- if learn is None and pytorch_colorizer is None:
303
- return None
304
  return colorize_pil(image)
305
  except Exception as e:
306
  logger.error("Gradio colorization error: %s", str(e))
 
28
  import torch
29
  import uvicorn
30
  import gradio as gr
31
+ import numpy as np
32
+ import cv2
33
 
34
  # FastAI imports
35
  from fastai.vision.all import *
 
213
  response["model_error"] = model_load_error
214
  return response
215
 
216
+ def simple_colorize_fallback(image: Image.Image) -> Image.Image:
217
+ """
218
+ Simple fallback colorization using LAB color space
219
+ This provides basic colorization when the model doesn't load
220
+ """
221
+ # Convert to LAB color space
222
+ if image.mode != "RGB":
223
+ image = image.convert("RGB")
224
+
225
+ # Convert to numpy array
226
+ img_array = np.array(image)
227
+
228
+ # Convert RGB to LAB
229
+ lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
230
+
231
+ # Apply simple colorization: enhance the L channel and add some color hints
232
+ l, a, b = cv2.split(lab)
233
+
234
+ # Enhance lightness
235
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
236
+ l = clahe.apply(l)
237
+
238
+ # Add subtle color hints to a and b channels
239
+ # This is a very basic approach - just adds some warm tones
240
+ a = np.clip(a.astype(np.float32) + 5, 0, 255).astype(np.uint8)
241
+ b = np.clip(b.astype(np.float32) + 10, 0, 255).astype(np.uint8)
242
+
243
+ # Merge channels and convert back to RGB
244
+ lab_colored = cv2.merge([l, a, b])
245
+ colored_rgb = cv2.cvtColor(lab_colored, cv2.COLOR_LAB2RGB)
246
+
247
+ return Image.fromarray(colored_rgb)
248
+
249
+
250
  def colorize_pil(image: Image.Image) -> Image.Image:
251
  """Run model prediction and return colorized image"""
252
  # Try FastAI first
 
287
  return pytorch_colorizer.colorize(image)
288
 
289
  else:
290
+ # Final fallback: simple colorization
291
+ logger.warning("No model loaded, using simple colorization fallback")
292
+ return simple_colorize_fallback(image)
293
 
294
  @app.post("/colorize")
295
  async def colorize_api(
 
300
  Upload a black & white image -> returns colorized image.
301
  Requires Firebase authentication unless DISABLE_AUTH=true
302
  """
303
+ # Allow fallback colorization even if model isn't loaded
304
+ # if learn is None and pytorch_colorizer is None:
305
+ # raise HTTPException(status_code=503, detail="Colorization model not loaded")
306
 
307
  if not file.content_type or not file.content_type.startswith("image/"):
308
  raise HTTPException(status_code=400, detail="File must be an image")
 
338
  if image is None:
339
  return None
340
  try:
341
+ # Always try to colorize, even with fallback
 
342
  return colorize_pil(image)
343
  except Exception as e:
344
  logger.error("Gradio colorization error: %s", str(e))
requirements.txt CHANGED
@@ -7,4 +7,6 @@ pillow
7
  firebase-admin
8
  fastai
9
  huggingface_hub
10
- pydantic-settings
 
 
 
7
  firebase-admin
8
  fastai
9
  huggingface_hub
10
+ pydantic-settings
11
+ opencv-python
12
+ numpy