Local AI Assistant committed
Commit a016784 · 1 Parent(s): a33b5b3

Implement lazy loading for AI models to save memory
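The pattern this commit applies: instead of constructing every engine at import time, each engine gets an accessor that builds it on first call and caches it in a module-level global, so a process that never serves an image request never pays for loading the image model. A minimal, self-contained sketch of the idea, with Engine as a hypothetical stand-in for the real ChatEngine / ImageEngine / RAGEngine classes:

class Engine:
    """Stand-in for an expensive model wrapper (weights load in __init__)."""
    def __init__(self):
        print("Loading model weights...")  # the cost we want to defer

_engine = None  # module-level cache; stays None until first use

def get_engine():
    """Build the engine on the first call, return the cached instance after."""
    global _engine
    if _engine is None:
        _engine = Engine()
    return _engine

get_engine()  # first call: prints "Loading model weights..."
get_engine()  # later calls: return the cached instance, no reload

The trade-off: startup is fast and idle memory stays low, but the first request that touches each engine absorbs its load time.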

Files changed (1)
  1. api.py +36 -11
api.py CHANGED
@@ -57,12 +57,31 @@ if firebase_admin._apps:
 else:
     db = None
 
-# Initialize engines
-print("Initializing AI Engines...")
-chat_engine = ChatEngine()
-image_engine = ImageEngine()
-rag_engine = RAGEngine()
-print("AI Engines Ready!")
+# Global engine instances (Lazy loaded)
+chat_engine = None
+image_engine = None
+rag_engine = None
+
+def get_chat_engine():
+    global chat_engine
+    if chat_engine is None:
+        print("Lazy loading Chat Engine...")
+        chat_engine = ChatEngine()
+    return chat_engine
+
+def get_image_engine():
+    global image_engine
+    if image_engine is None:
+        print("Lazy loading Image Engine...")
+        image_engine = ImageEngine()
+    return image_engine
+
+def get_rag_engine():
+    global rag_engine
+    if rag_engine is None:
+        print("Lazy loading RAG Engine...")
+        rag_engine = RAGEngine()
+    return rag_engine
 
 # Auth Dependency
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@@ -222,8 +241,10 @@ async def chat(request: ChatRequest, current_user: dict = Depends(get_current_user)):
     # ... (Keep existing /chat for backward compatibility if needed, or redirect logic)
     # For now, let's keep /chat as blocking and add /chat/stream
     try:
+        # Get engine (lazy load)
+        engine = get_chat_engine()
         # Generate Response
-        response = chat_engine.generate_response(request.message, request.history)
+        response = engine.generate_response(request.message, request.history)
 
         # Save to Firestore if conversation_id is present
         if request.conversation_id:
@@ -261,7 +282,8 @@ async def upload_file(file: UploadFile = File(...), current_user: dict = Depends(get_current_user)):
             shutil.copyfileobj(file.file, buffer)
 
         # Ingest into RAG
-        rag_engine.ingest_file(file_path)
+        rag = get_rag_engine()
+        rag.ingest_file(file_path)
 
         return {"filename": file.filename, "status": "ingested"}
     except Exception as e:
@@ -272,7 +294,8 @@ async def chat_stream(request: ChatRequest, current_user: dict = Depends(get_current_user)):
     try:
         # Check for RAG context
         context = ""
-        rag_docs = rag_engine.search(request.message)
+        rag = get_rag_engine()
+        rag_docs = rag.search(request.message)
         if rag_docs:
             context = "\n\nRelevant Context:\n" + "\n".join(rag_docs) + "\n\n"
             print(f"Found {len(rag_docs)} relevant documents.")
@@ -292,7 +315,8 @@ async def chat_stream(request: ChatRequest, current_user: dict = Depends(get_current_user)):
        # Prepend context to the message sent to AI (but not saved in DB as user message)
         augmented_message = context + request.message if context else request.message
 
-        for token in chat_engine.generate_stream(augmented_message, request.history, request.language):
+        engine = get_chat_engine()
+        for token in engine.generate_stream(augmented_message, request.history, request.language):
             full_response += token
             yield token
 
@@ -319,7 +343,8 @@ async def generate_image(request: ImageRequest, current_user: dict = Depends(get_current_user)):
     try:
         # Generate image to a temporary file
         filename = "temp_generated.png"
-        image_engine.generate_image(request.prompt, output_path=filename)
+        engine = get_image_engine()
+        engine.generate_image(request.prompt, output_path=filename)
 
         # Read and encode to base64 to send to frontend
         with open(filename, "rb") as image_file:
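One caveat worth noting about module-global lazy initialization. As written, the endpoints are async def, and there is no await between the None check and the assignment, so the check-then-construct runs without interleaving on the event loop. But if an accessor were ever invoked from worker threads (for example from a sync endpoint, which FastAPI runs on a threadpool), two first callers could both observe None and each construct a model, briefly doubling memory before one instance is discarded. A hedged sketch of a lock-guarded variant, assuming the same globals as the diff above; the ChatEngine stub stands in for the real class imported in api.py:

import threading

class ChatEngine:  # stub standing in for the real engine class in api.py
    pass

_chat_lock = threading.Lock()
chat_engine = None

def get_chat_engine():
    """Thread-safe accessor: only one thread ever constructs the engine."""
    global chat_engine
    if chat_engine is None:              # fast path once initialized
        with _chat_lock:
            if chat_engine is None:      # double-checked under the lock
                chat_engine = ChatEngine()
    return chat_engine

The double check keeps the common case lock-free while guaranteeing a single construction if two first requests ever race.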