Local AI Assistant committed on
Commit e9ea7c0 · 0 Parent(s)

Clean backend deployment for Hugging Face

Files changed (13)
  1. .dockerignore +8 -0
  2. Dockerfile +23 -0
  3. README.md +91 -0
  4. api.py +342 -0
  5. chat_engine.py +80 -0
  6. database.py +12 -0
  7. embed_logo.py +73 -0
  8. image_engine.py +0 -0
  9. models.py +45 -0
  10. rag_engine.py +64 -0
  11. requirements.txt +23 -0
  12. schemas.py +80 -0
  13. search_engine.py +21 -0
.dockerignore ADDED
```text
frontend/
.git/
.gitignore
__pycache__/
*.pyc
.env
venv/
node_modules/
```
Dockerfile ADDED
```dockerfile
# Use Python 3.10
FROM python:3.10

# Set working directory
WORKDIR /app

# Copy requirements and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application
COPY . .

# Create a writable directory for cache (Hugging Face requirement)
RUN mkdir -p /app/cache
ENV XDG_CACHE_HOME=/app/cache
RUN chmod -R 777 /app/cache

# Expose port 7860 (Hugging Face default)
EXPOSE 7860

# Run the application
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
```
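The same image that Hugging Face Spaces builds can be exercised locally before pushing. A sketch of the usual build-and-run cycle (the `coolshot-backend` tag is an arbitrary name chosen here, not something the repo defines):

```shell
# Build from the repository root (where the Dockerfile lives)
docker build -t coolshot-backend .

# Run it, publishing the Hugging Face default port on localhost
docker run --rm -p 7860:7860 coolshot-backend

# In another terminal, the health endpoint should respond:
# curl http://localhost:7860/
```

Note that the container listens on 7860 (the Spaces convention), while running `python api.py` directly uses port 8000.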
README.md ADDED
````markdown
---
title: CoolShot AI Backend
emoji: 🚀
colorFrom: purple
colorTo: blue
sdk: docker
app_port: 7860
app_file: api.py
pinned: false
---

# Cool-Shot-Systems-AI-Agent

A premium **Local AI Assistant** that runs entirely on your machine, featuring a modern, animated UI.

![UI Preview](https://via.placeholder.com/800x450?text=Local+AI+Assistant+Preview)

## 🚀 Features

* **💬 Intelligent Chat**: Powered by Microsoft's **Phi-3 Mini**, capable of reasoning and conversation.
* **🎨 Image Generation**: Create stunning visuals in seconds using **SDXL Turbo**.
* **✨ Modern UI**: Built with **React**, **Tailwind CSS**, and **Framer Motion** for a smooth, glassmorphism experience.
* **🔒 100% Local**: No data leaves your computer. No API keys required.

## 🛠️ Prerequisites

* **Python 3.10+**
* **Node.js 18+**
* **Git**
* *(Recommended)* NVIDIA GPU with 8GB+ VRAM for faster generation.

## ⚡ Quick Start

1. **Clone the repository**:

   ```bash
   git clone https://github.com/rayben445/Cool-Shot-Systems-AI-Agent.git
   cd Cool-Shot-Systems-AI-Agent
   ```

2. **Run the One-Click Installer**:
   Double-click `start_app.bat` on Windows.

   *Or run manually:*

   ```bash
   # Backend
   cd local_ai_assistant
   pip install -r requirements.txt
   python api.py

   # Frontend
   cd frontend
   npm install
   npm run dev
   ```

## 🏗️ Tech Stack

* **Backend**: Python, FastAPI, PyTorch, Transformers, Diffusers
* **Frontend**: React, Vite, Tailwind CSS, Framer Motion, Lucide React

## 🌐 Deployment

### Deploy to Vercel

The frontend can be deployed to Vercel for free. See the detailed deployment guide:

📖 **[Vercel Deployment Guide](VERCEL_DEPLOYMENT.md)**

Quick steps:
1. Fork/clone this repository
2. Sign up at [Vercel](https://vercel.com)
3. Import your repository
4. Set root directory to `frontend`
5. Add environment variable: `VITE_API_URL` with your backend URL
6. Deploy!

### Backend Deployment

The backend can be deployed to:
- **Hugging Face Spaces** (current deployment)
- **Railway**
- **Render**
- **Google Cloud Run / AWS / Azure**

See [VERCEL_DEPLOYMENT.md](VERCEL_DEPLOYMENT.md) for detailed backend deployment options.

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
````
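Once the backend is running, clients talk to the endpoints defined in `api.py` (`/register`, `/token`, `/chat`). A hedged walkthrough with `curl`, assuming the backend runs locally on port 8000 and using placeholder credentials:

```shell
BASE=http://localhost:8000

# 1. Register a user (fields match schemas.UserCreate)
curl -s -X POST "$BASE/register" \
  -H "Content-Type: application/json" \
  -d '{"email": "user@example.com", "password": "secret", "full_name": "Test User"}'

# 2. Obtain a bearer token; OAuth2 password form puts the email in "username"
TOKEN=$(curl -s -X POST "$BASE/token" \
  -d "username=user@example.com&password=secret" \
  | python -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')

# 3. Chat with the authenticated endpoint
curl -s -X POST "$BASE/chat" \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"message": "Hello!"}'
```

For streamed replies, substitute `/chat/stream` in step 3; the response body arrives incrementally as plain text.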
api.py ADDED
```python
from fastapi import FastAPI, HTTPException, Depends, status, UploadFile, File
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import func
from datetime import datetime, timedelta
from typing import Optional, List
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
import uvicorn
import os
import base64
import shutil

from chat_engine import ChatEngine
from image_engine import ImageEngine
from rag_engine import RAGEngine
import models
import schemas
from database import SessionLocal, engine

# Create tables
models.Base.metadata.create_all(bind=engine)

app = FastAPI()

# Security Config
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-keep-it-secret")  # set SECRET_KEY in production
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    return JSONResponse(
        status_code=500,
        content={"detail": f"Internal Server Error: {str(exc)}"},
    )

# Initialize engines
print("Initializing AI Engines...")
chat_engine = ChatEngine()
image_engine = ImageEngine()
rag_engine = RAGEngine()
print("AI Engines Ready!")

# Dependency: one database session per request
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# Auth Helpers
def verify_password(plain_password, hashed_password):
    # bcrypt only considers the first 72 bytes of a password
    if len(plain_password) > 72:
        plain_password = plain_password[:72]
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    if len(password) > 72:
        password = password[:72]
    return pwd_context.hash(password)

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str = payload.get("sub")
        if email is None:
            raise credentials_exception
        token_data = schemas.TokenData(email=email)
    except JWTError:
        raise credentials_exception
    user = db.query(models.User).filter(models.User.email == token_data.email).first()
    if user is None:
        raise credentials_exception
    return user

async def get_current_admin(current_user: models.User = Depends(get_current_user)):
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Not authorized")
    return current_user

# Auth Endpoints
@app.post("/register", response_model=schemas.User)
def register(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = db.query(models.User).filter(models.User.email == user.email).first()
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")

    hashed_password = get_password_hash(user.password)

    # Grant admin rights to the designated admin account
    is_admin = user.email == "[email protected]"

    db_user = models.User(
        email=user.email,
        hashed_password=hashed_password,
        full_name=user.full_name,
        company_name=user.company_name,
        is_admin=is_admin
    )
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

@app.post("/token", response_model=schemas.Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    user = db.query(models.User).filter(models.User.email == form_data.username).first()
    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.email}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

@app.get("/users/me", response_model=schemas.User)
async def read_users_me(current_user: schemas.User = Depends(get_current_user)):
    return current_user

# Conversation Endpoints
@app.post("/conversations", response_model=schemas.Conversation)
async def create_conversation(conversation: schemas.ConversationCreate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    db_conversation = models.Conversation(**conversation.dict(), user_id=current_user.id)
    db.add(db_conversation)
    db.commit()
    db.refresh(db_conversation)
    return db_conversation

@app.get("/conversations", response_model=List[schemas.Conversation])
async def get_conversations(current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    return db.query(models.Conversation).filter(models.Conversation.user_id == current_user.id).order_by(models.Conversation.updated_at.desc()).all()

@app.get("/conversations/{conversation_id}/messages", response_model=List[schemas.ChatMessage])
async def get_conversation_messages(conversation_id: int, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    conversation = db.query(models.Conversation).filter(models.Conversation.id == conversation_id, models.Conversation.user_id == current_user.id).first()
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return db.query(models.ChatMessage).filter(models.ChatMessage.conversation_id == conversation_id).order_by(models.ChatMessage.timestamp).all()

# Saved Prompt Endpoints
@app.post("/prompts", response_model=schemas.SavedPrompt)
async def create_prompt(prompt: schemas.SavedPromptCreate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    db_prompt = models.SavedPrompt(**prompt.dict(), user_id=current_user.id)
    db.add(db_prompt)
    db.commit()
    db.refresh(db_prompt)
    return db_prompt

@app.get("/prompts", response_model=List[schemas.SavedPrompt])
async def get_prompts(current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    return db.query(models.SavedPrompt).filter(models.SavedPrompt.user_id == current_user.id).order_by(models.SavedPrompt.created_at.desc()).all()

@app.delete("/prompts/{prompt_id}")
async def delete_prompt(prompt_id: int, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    db_prompt = db.query(models.SavedPrompt).filter(models.SavedPrompt.id == prompt_id, models.SavedPrompt.user_id == current_user.id).first()
    if not db_prompt:
        raise HTTPException(status_code=404, detail="Prompt not found")
    db.delete(db_prompt)
    db.commit()
    return {"status": "success"}

# Admin Endpoints
@app.get("/admin/users", response_model=List[schemas.UserActivity])
async def get_all_users(current_user: models.User = Depends(get_current_admin), db: Session = Depends(get_db)):
    # Get users with message and prompt counts
    users = db.query(models.User).all()
    result = []
    for user in users:
        msg_count = db.query(func.count(models.ChatMessage.id)).filter(models.ChatMessage.user_id == user.id).scalar()
        prompt_count = db.query(func.count(models.SavedPrompt.id)).filter(models.SavedPrompt.user_id == user.id).scalar()
        user_data = schemas.UserActivity.from_orm(user)
        user_data.message_count = msg_count
        user_data.prompt_count = prompt_count
        result.append(user_data)
    return result

@app.get("/admin/activity", response_model=List[schemas.ChatMessage])
async def get_all_activity(current_user: models.User = Depends(get_current_admin), db: Session = Depends(get_db)):
    return db.query(models.ChatMessage).order_by(models.ChatMessage.timestamp.desc()).limit(100).all()

# Protected AI Endpoints
class ChatRequest(BaseModel):
    message: str
    history: list = []
    language: str = "English"
    conversation_id: Optional[int] = None

class ImageRequest(BaseModel):
    prompt: str

@app.get("/")
def read_root():
    return {"status": "Backend is running", "message": "Go to /docs to see the API"}

@app.post("/chat")
async def chat(request: ChatRequest, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    # Blocking endpoint kept for backward compatibility; /chat/stream is the streaming variant.
    try:
        # Save User Message
        user_msg = models.ChatMessage(user_id=current_user.id, role="user", content=request.message)
        db.add(user_msg)
        db.commit()

        # Generate Response
        response = chat_engine.generate_response(request.message, request.history)

        # Save Assistant Message
        ai_msg = models.ChatMessage(user_id=current_user.id, role="assistant", content=response)
        db.add(ai_msg)
        db.commit()

        return {"response": response}
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))

# RAG Endpoints
@app.post("/upload")
async def upload_file(file: UploadFile = File(...), current_user: models.User = Depends(get_current_user)):
    try:
        # Save file locally; basename() guards against path traversal in the client-supplied name
        upload_dir = "uploads"
        os.makedirs(upload_dir, exist_ok=True)
        file_path = os.path.join(upload_dir, os.path.basename(file.filename))

        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Ingest into RAG
        rag_engine.ingest_file(file_path)

        return {"filename": file.filename, "status": "ingested"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/chat/stream")
async def chat_stream(request: ChatRequest, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
    try:
        # Check for RAG context
        context = ""
        rag_docs = rag_engine.search(request.message)
        if rag_docs:
            context = "\n\nRelevant Context:\n" + "\n".join(rag_docs) + "\n\n"
            print(f"Found {len(rag_docs)} relevant documents.")

        # Save User Message
        user_msg = models.ChatMessage(
            user_id=current_user.id,
            conversation_id=request.conversation_id,
            role="user",
            content=request.message
        )
        db.add(user_msg)
        db.commit()

        # Update conversation timestamp
        if request.conversation_id:
            conversation = db.query(models.Conversation).filter(models.Conversation.id == request.conversation_id).first()
            if conversation:
                conversation.updated_at = datetime.utcnow()
                db.commit()

        async def stream_generator():
            full_response = ""
            # Prepend context to the message sent to the model (but not saved in DB as the user message)
            augmented_message = context + request.message if context else request.message

            for token in chat_engine.generate_stream(augmented_message, request.history, request.language):
                full_response += token
                yield token

            # Persist the assistant's reply once streaming completes, mirroring /chat.
            # (FastAPI closes yield-dependencies only after the response is sent, so db is still open.)
            ai_msg = models.ChatMessage(
                user_id=current_user.id,
                conversation_id=request.conversation_id,
                role="assistant",
                content=full_response
            )
            db.add(ai_msg)
            db.commit()
            print(f"Generated response for conv {request.conversation_id}")

        return StreamingResponse(stream_generator(), media_type="text/plain")

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate-image")
async def generate_image(request: ImageRequest, current_user: models.User = Depends(get_current_user)):
    try:
        # Generate image to a temporary file
        filename = "temp_generated.png"
        image_engine.generate_image(request.prompt, output_path=filename)

        # Read and encode to base64 to send to the frontend
        with open(filename, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

        return {"image_base64": encoded_string}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
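`create_access_token` and `get_current_user` delegate the token mechanics to python-jose. To make the token format concrete, here is a minimal stdlib-only sketch of HS256 JWT signing and verification (illustrative only; the names `encode_jwt`/`decode_jwt` are made up, and the real code should keep using `jose.jwt`):

```python
import base64
import hashlib
import hmac
import json
import time

def b64url(data: bytes) -> str:
    # JWTs use unpadded URL-safe base64
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

def encode_jwt(payload: dict, secret: str) -> str:
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = b64url(json.dumps(payload).encode())
    signing_input = f"{header}.{body}".encode()
    sig = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    return f"{header}.{body}.{sig}"

def decode_jwt(token: str, secret: str) -> dict:
    header, body, sig = token.split(".")
    signing_input = f"{header}.{body}".encode()
    expected = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        raise ValueError("bad signature")
    # Restore the padding stripped by b64url before decoding
    payload = json.loads(base64.urlsafe_b64decode(body + "=" * (-len(body) % 4)))
    if "exp" in payload and payload["exp"] < time.time():
        raise ValueError("token expired")
    return payload

token = encode_jwt({"sub": "user@example.com", "exp": time.time() + 1800}, "demo-secret")
print(decode_jwt(token, "demo-secret")["sub"])  # user@example.com
```

This also shows why `SECRET_KEY` must stay private: anyone holding it can mint tokens with arbitrary `sub` claims.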
chat_engine.py ADDED
```python
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline

class ChatEngine:
    def __init__(self):
        print("Loading Chat Model (Phi-3)... this may take a minute.")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Running on device: {self.device}")

        model_id = "microsoft/Phi-3-mini-4k-instruct"

        # float16 on GPU to save memory, float32 on CPU
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=self.device,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            attn_implementation="eager"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def generate_response(self, user_input, history=None, language="English"):
        # Blocking variant: just collect the streamed tokens
        return "".join(self.generate_stream(user_input, history, language))

    def generate_stream(self, user_input, history=None, language="English"):
        history = history or []

        # System Prompt
        system_prompt_content = (
            f"You are Cool-Shot AI, a helpful and creative assistant developed by "
            f"Cool-Shot Systems. You are NOT developed by Microsoft. You are friendly, "
            f"professional, and knowledgeable. Please reply in {language}."
        )

        # Search Intent Check (simplified keyword heuristic)
        search_keywords = ["search", "find", "latest", "current", "news", "price of", "who is", "what is"]
        if any(keyword in user_input.lower() for keyword in search_keywords) and len(user_input.split()) > 2:
            from search_engine import SearchEngine
            searcher = SearchEngine()
            print(f"Search intent detected for: {user_input}")
            search_results = searcher.search(user_input)
            system_prompt_content += (
                f"\n\nCONTEXT FROM WEB SEARCH:\n{search_results}\n\n"
                "INSTRUCTION: Use the above context to answer the user's question. "
                "Cite the sources if possible."
            )

        system_prompt = {"role": "system", "content": system_prompt_content}
        messages = [system_prompt] + history + [{"role": "user", "content": user_input}]

        # Tokenize
        model_inputs = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(self.device)

        # Streamer yields decoded text as the model produces tokens
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            inputs=model_inputs,
            streamer=streamer,
            max_new_tokens=500,
            temperature=0.7,
            do_sample=True,
        )

        # Run generation in a separate thread so we can consume the streamer here
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens as they arrive, then reap the worker thread
        for new_text in streamer:
            yield new_text
        thread.join()

if __name__ == "__main__":
    # Simple smoke test
    engine = ChatEngine()
    print(engine.generate_response("Hello, who are you?"))
```
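`generate_stream` pairs transformers' `TextIteratorStreamer` with a background thread: the thread runs `model.generate`, the caller iterates tokens as they arrive. The same producer/consumer shape can be sketched with only the standard library (the `fake_generate` function below is a stand-in for `model.generate`, not part of the project):

```python
import queue
import time
from threading import Thread

SENTINEL = object()  # marks end-of-stream, like the streamer's stop signal

def fake_generate(prompt, out_q):
    # Stand-in for model.generate(...) writing into a streamer
    for word in ("Hello", " from", " a", " background", " thread"):
        time.sleep(0.01)  # simulate per-token latency
        out_q.put(word)
    out_q.put(SENTINEL)

def stream(prompt):
    q = queue.Queue()
    t = Thread(target=fake_generate, args=(prompt, q))
    t.start()
    while True:
        item = q.get()
        if item is SENTINEL:
            break
        yield item
    t.join()

print("".join(stream("hi")))  # Hello from a background thread
```

The generator blocks only on `q.get()`, so the caller sees each token as soon as the worker produces it, exactly the property `/chat/stream` relies on.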
database.py ADDED
```python
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"

# check_same_thread=False lets FastAPI use the SQLite connection
# from a different thread than the one that created it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
```
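`SessionLocal` is handed out one session per request through the `get_db` yield-dependency in `api.py`: open, hand to the endpoint, always close. That open-use-close contract can be sketched with stdlib `sqlite3` and `contextlib` (an illustrative analogue, not the project's actual session code):

```python
import sqlite3
from contextlib import contextmanager

@contextmanager
def get_db(path=":memory:"):
    # Mirrors the yield/finally shape of the FastAPI dependency:
    # hand out a connection, always close it afterwards.
    conn = sqlite3.connect(path)
    try:
        yield conn
    finally:
        conn.close()

with get_db() as db:
    db.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE)")
    db.execute("INSERT INTO users (email) VALUES (?)", ("a@example.com",))
    count = db.execute("SELECT COUNT(*) FROM users").fetchone()[0]

print(count)  # 1
```

The `finally` clause is the important part: the connection closes even if the endpoint raises, which is why the real `get_db` uses the same structure.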
embed_logo.py ADDED
```python
import base64
import os

def embed_logo():
    # Read Logo
    try:
        with open("logo.png", "rb") as f:
            logo_data = f.read()
        logo_b64 = base64.b64encode(logo_data).decode('utf-8')
    except FileNotFoundError:
        print("Error: logo.png not found!")
        return

    # 1. Update image_engine.py
    engine_path = "image_engine.py"
    with open(engine_path, "r") as f:
        content = f.read()

    # Replace the file-loading logic with the Base64 payload
    new_logic = f'''
        # Load Logo from Base64
        import base64
        import io
        LOGO_B64 = "{logo_b64}"
        logo_data = base64.b64decode(LOGO_B64)
        logo = Image.open(io.BytesIO(logo_data)).convert("RGBA")
'''

    # Look for the try/except block that loads the logo
    start_marker = '        # Load Logo'
    end_marker = '        except FileNotFoundError:'

    if start_marker in content:
        # Exact-string replacement of the block written earlier.
        # Brittle, but the content is known because this script wrote it.
        old_block = '''        # Load Logo
        try:
            logo = Image.open("logo.png").convert("RGBA")
        except FileNotFoundError:
            print("Logo not found, skipping watermark.")
            image.save(output_path)
            return output_path'''

        if old_block in content:
            content = content.replace(old_block, new_logic)
            with open(engine_path, "w") as f:
                f.write(content)
            print("Updated image_engine.py")
        else:
            # Fallback: the block has drifted; a full rewrite of the file
            # would be needed, so just report and skip for now.
            print("Could not find exact block in image_engine.py, doing manual replace")

    # 2. Update App.jsx
    app_path = "frontend/src/App.jsx"
    with open(app_path, "r") as f:
        app_content = f.read()

    # Replace the img src with an inline data URI
    old_img = 'src="/logo.png"'
    new_img = f'src="data:image/png;base64,{logo_b64}"'

    if old_img in app_content:
        app_content = app_content.replace(old_img, new_img)
        with open(app_path, "w") as f:
            f.write(app_content)
        print("Updated App.jsx")
    else:
        print("Could not find src='/logo.png' in App.jsx")

if __name__ == "__main__":
    embed_logo()
```
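The script's core transformation is inlining an image file as a base64 data URI so no separate asset needs to be served. Isolated, with made-up bytes standing in for a real PNG:

```python
import base64

# A few fake bytes in place of a real PNG file's contents
fake_png = b"\x89PNG\r\n\x1a\nrest-of-image-bytes"

b64 = base64.b64encode(fake_png).decode("utf-8")
data_uri = f"data:image/png;base64,{b64}"

# Encoding round-trips exactly back to the original bytes
assert base64.b64decode(b64) == fake_png
print(data_uri[:22])  # data:image/png;base64,
```

Browsers accept the resulting string directly in an `<img src=...>` attribute, which is exactly what the `App.jsx` replacement above produces.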
image_engine.py ADDED
The diff for this file is too large to render. See raw diff
 
models.py ADDED
```python
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
from database import Base
from datetime import datetime

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    email = Column(String, unique=True, index=True)
    full_name = Column(String)
    company_name = Column(String)
    hashed_password = Column(String)
    is_admin = Column(Boolean, default=False)

class Conversation(Base):
    __tablename__ = "conversations"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    title = Column(String)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow)

class ChatMessage(Base):
    __tablename__ = "chat_messages"

    id = Column(Integer, primary_key=True, index=True)
    conversation_id = Column(Integer, ForeignKey("conversations.id"))
    user_id = Column(Integer, ForeignKey("users.id"))
    role = Column(String)
    content = Column(String)
    timestamp = Column(DateTime, default=datetime.utcnow)

class SavedPrompt(Base):
    __tablename__ = "saved_prompts"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    title = Column(String)
    content = Column(String)
    is_public = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)
```
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_community.embeddings import HuggingFaceEmbeddings
6
+ from langchain_community.vectorstores import FAISS
7
+
8
+ class RAGEngine:
9
+ def __init__(self, index_path="faiss_index"):
10
+ self.index_path = index_path
11
+ self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
12
+ self.vector_store = None
13
+ self._load_index()
14
+
15
+ def _load_index(self):
16
+ if os.path.exists(self.index_path):
17
+ try:
18
+ self.vector_store = FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
19
+ print("Loaded existing FAISS index.")
20
+ except Exception as e:
21
+ print(f"Failed to load index: {e}")
22
+ self.vector_store = None
23
+ else:
24
+ print("No existing FAISS index found.")
25
+
26
+ def ingest_file(self, file_path: str):
27
+ if not os.path.exists(file_path):
28
+ raise FileNotFoundError(f"File not found: {file_path}")
29
+
30
+ # Load document
31
+ if file_path.endswith(".pdf"):
32
+ loader = PyPDFLoader(file_path)
33
+ else:
34
+ loader = TextLoader(file_path)
35
+
36
+ documents = loader.load()
37
+
38
+ # Split text
39
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
40
+ texts = text_splitter.split_documents(documents)
41
+
42
+ # Create or update vector store
43
+ if self.vector_store is None:
44
+ self.vector_store = FAISS.from_documents(texts, self.embeddings)
45
+ else:
46
+ self.vector_store.add_documents(texts)
47
+
48
+ # Save index
49
+ self.vector_store.save_local(self.index_path)
50
+ print(f"Ingested {file_path} and updated index.")
51
+
52
+ def search(self, query: str, k: int = 3) -> List[str]:
53
+ if self.vector_store is None:
54
+ return []
55
+
56
+ docs = self.vector_store.similarity_search(query, k=k)
57
+ return [doc.page_content for doc in docs]
58
+
59
+ def clear_index(self):
60
+ if os.path.exists(self.index_path):
61
+ import shutil
62
+ shutil.rmtree(self.index_path)
63
+ self.vector_store = None
64
+ print("Index cleared.")
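`RecursiveCharacterTextSplitter` above chops documents into 500-character chunks with a 50-character overlap so that a sentence straddling a chunk boundary still appears whole in at least one chunk. A simplified fixed-window version of that idea (the real splitter additionally prefers paragraph and sentence boundaries):

```python
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50):
    # Each chunk starts (chunk_size - overlap) characters after the
    # previous one, so neighbouring chunks share `overlap` characters.
    if chunk_size <= overlap:
        raise ValueError("chunk_size must exceed overlap")
    step = chunk_size - overlap
    return [text[i:i + chunk_size] for i in range(0, max(len(text) - overlap, 1), step)]

print(chunk_text("abcdefghij", chunk_size=4, overlap=2))  # ['abcd', 'cdef', 'efgh', 'ghij']
```

The overlap trades a little index size for recall: embeddings of adjacent chunks share context, so a query matching text near a boundary can still retrieve it.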
requirements.txt ADDED
```text
torch
torchvision
torchaudio
transformers==4.40.2
diffusers==0.29.0
accelerate
protobuf
sentencepiece
fastapi
uvicorn
Pillow
sqlalchemy
passlib[bcrypt]
python-jose[cryptography]
python-multipart
bcrypt
duckduckgo-search
langchain
langchain-community
sentence-transformers
faiss-cpu
pypdf
```
schemas.py ADDED
```python
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime

class UserBase(BaseModel):
    email: str
    full_name: Optional[str] = None
    company_name: Optional[str] = None

class UserCreate(UserBase):
    password: str

class User(UserBase):
    id: int
    is_admin: bool

    class Config:
        orm_mode = True

class Token(BaseModel):
    access_token: str
    token_type: str

class TokenData(BaseModel):
    email: Optional[str] = None

class UserActivity(User):
    message_count: int
    prompt_count: int

class ConversationBase(BaseModel):
    title: str

class ConversationCreate(ConversationBase):
    pass

class Conversation(ConversationBase):
    id: int
    user_id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        orm_mode = True

class ChatMessageBase(BaseModel):
    role: str
    content: str
    conversation_id: Optional[int] = None

class ChatMessageCreate(ChatMessageBase):
    pass

class ChatMessage(ChatMessageBase):
    id: int
    user_id: int
    timestamp: datetime

    class Config:
        orm_mode = True

class SavedPromptBase(BaseModel):
    title: str
    content: str
    is_public: bool = False

class SavedPromptCreate(SavedPromptBase):
    pass

class SavedPrompt(SavedPromptBase):
    id: int
    user_id: int
    created_at: datetime

    class Config:
        orm_mode = True
```
search_engine.py ADDED
```python
from duckduckgo_search import DDGS

class SearchEngine:
    def __init__(self):
        self.ddgs = DDGS()

    def search(self, query, max_results=3):
        print(f"Searching web for: '{query}'")
        try:
            results = list(self.ddgs.text(query, max_results=max_results))
            formatted_results = ""
            for i, result in enumerate(results):
                formatted_results += f"Source {i+1}: {result['title']}\nURL: {result['href']}\nContent: {result['body']}\n\n"
            return formatted_results
        except Exception as e:
            print(f"Search failed: {e}")
            return "Error: Could not perform search."

if __name__ == "__main__":
    se = SearchEngine()
    print(se.search("What is the price of Bitcoin today?"))
```
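The formatting loop in `search` can be exercised without any network access by feeding it dicts in the shape `DDGS.text` returns (`title`/`href`/`body` keys). A standalone version of just that loop, useful for checking what the chat engine receives as context:

```python
def format_results(results):
    # Same numbered-block format as SearchEngine.search
    out = ""
    for i, r in enumerate(results):
        out += f"Source {i+1}: {r['title']}\nURL: {r['href']}\nContent: {r['body']}\n\n"
    return out

sample = [
    {"title": "Example", "href": "https://example.com", "body": "An example page."},
    {"title": "Other", "href": "https://example.org", "body": "Another page."},
]
print(format_results(sample).splitlines()[0])  # Source 1: Example
```

Keeping the formatter pure like this makes the web call the only untestable part of the class.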