File size: 8,613 Bytes
eac6673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# rag_system.py

import os
import logging
import shutil
import json
from typing import Optional

from rag_components import KnowledgeRAG
from utils import download_and_unzip_gdrive_folder
from config import (
    GROQ_API_KEY, GDRIVE_SOURCES_ENABLED, GDRIVE_FOLDER_ID_OR_URL, RAG_SOURCES_DIR,
    RAG_STORAGE_PARENT_DIR, RAG_FAISS_INDEX_SUBDIR_NAME, RAG_LOAD_INDEX_ON_STARTUP,
    RAG_EMBEDDING_MODEL_NAME, RAG_LLM_MODEL_NAME,
    RAG_EMBEDDING_USE_GPU, RAG_LLM_TEMPERATURE, RAG_CHUNK_SIZE, RAG_CHUNK_OVERLAP,
    RAG_RERANKER_MODEL_NAME, RAG_RERANKER_ENABLED, RAG_CHUNKED_SOURCES_FILENAME
)

logger = logging.getLogger(__name__)

# MODIFIED: Added source_dir_override parameter
def _clear_directory_contents(dir_path: str) -> None:
    """Best-effort removal of every file, symlink and subdirectory inside *dir_path*."""
    logger.info(f"[RAG_SYSTEM_INIT] Clearing existing contents of {dir_path}")
    try:
        for item_name in os.listdir(dir_path):
            item_path = os.path.join(dir_path, item_name)
            if os.path.isfile(item_path) or os.path.islink(item_path):
                os.unlink(item_path)
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path)
        logger.info(f"[RAG_SYSTEM_INIT] Successfully cleared {dir_path}")
    except Exception as e_clear:
        # Non-fatal: a partially cleared directory still lets the download proceed.
        logger.error(f"[RAG_SYSTEM_INIT] Could not clear {dir_path}: {e_clear}")


def _refresh_sources_from_gdrive() -> None:
    """Repopulate RAG_SOURCES_DIR from the configured Google Drive folder (best effort)."""
    logger.info("[RAG_SYSTEM_INIT] Google Drive sources download is ENABLED")
    if not GDRIVE_FOLDER_ID_OR_URL:
        # FIX: this message previously referenced a non-existent GDRIVE_FOLDER_URL setting.
        logger.warning("[RAG_SYSTEM_INIT] GDRIVE_SOURCES_ENABLED is True but GDRIVE_FOLDER_ID_OR_URL not set")
        return

    logger.info(f"[RAG_SYSTEM_INIT] Downloading from Google Drive: {GDRIVE_FOLDER_ID_OR_URL}")
    if os.path.isdir(RAG_SOURCES_DIR):
        _clear_directory_contents(RAG_SOURCES_DIR)

    if download_and_unzip_gdrive_folder(GDRIVE_FOLDER_ID_OR_URL, RAG_SOURCES_DIR):
        logger.info("[RAG_SYSTEM_INIT] Successfully populated sources from Google Drive")
    else:
        # Download failure is logged but not fatal: a previously built index
        # (or pre-chunked JSON) may still allow initialization to succeed.
        logger.error("[RAG_SYSTEM_INIT] Failed to download sources from Google Drive")


def _write_dummy_index_files(faiss_index_actual_path: str) -> None:
    """Create empty index.faiss / index.pkl placeholders when no source data exists.

    NOTE(review): these zero-byte text files are not a loadable FAISS index;
    presumably they exist so a later load attempt fails predictably rather
    than on a missing directory — confirm against load_index_from_disk.
    """
    os.makedirs(faiss_index_actual_path, exist_ok=True)
    for dummy_name in ("index.faiss", "index.pkl"):
        with open(os.path.join(faiss_index_actual_path, dummy_name), "w") as f_dummy:
            f_dummy.write("")
    logger.info("[RAG_SYSTEM_INIT] Created dummy index files")


# MODIFIED: Added source_dir_override parameter
def initialize_and_get_rag_system(force_rebuild: bool = False, source_dir_override: Optional[str] = None) -> Optional[KnowledgeRAG]:
    """Initialize and return the KnowledgeRAG system.

    Uses module-level configuration constants. Downloads sources from Google
    Drive when configured (and no override is given), optionally loads a
    pre-built FAISS index from disk, and otherwise builds one from source
    files.

    Args:
        force_rebuild: Delete any existing FAISS index and rebuild from sources.
        source_dir_override: Use this directory of source files instead of
            RAG_SOURCES_DIR. When set, the Google Drive download step is skipped.

    Returns:
        A ready KnowledgeRAG instance, or None on any failure (missing API
        key, invalid override directory, no source data, or a build error).
    """
    logger.info("[RAG_SYSTEM_INIT] ========== Initializing RAG System ==========")

    if not GROQ_API_KEY:
        logger.error("[RAG_SYSTEM_INIT] Groq API Key (BOT_API_KEY) not found. RAG system cannot be initialized.")
        return None

    # FIX: validate the override *before* resolving the directory so an invalid
    # override fails fast instead of being momentarily aliased to the default.
    if source_dir_override and not os.path.isdir(source_dir_override):
        logger.error(f"[RAG_SYSTEM_INIT] Custom source directory override '{source_dir_override}' not found. Aborting.")
        return None

    source_dir_to_use = source_dir_override or RAG_SOURCES_DIR
    logger.info(f"[RAG_SYSTEM_INIT] Using source directory: '{source_dir_to_use}'")

    # Only touch Google Drive / RAG_SOURCES_DIR when no custom directory is in use.
    if source_dir_override is None or not source_dir_override:
        if GDRIVE_SOURCES_ENABLED:
            _refresh_sources_from_gdrive()
        else:
            logger.info("[RAG_SYSTEM_INIT] Google Drive sources download is DISABLED")

    faiss_index_actual_path = os.path.join(RAG_STORAGE_PARENT_DIR, RAG_FAISS_INDEX_SUBDIR_NAME)
    processed_files_metadata_path = os.path.join(faiss_index_actual_path, "processed_files.json")

    if force_rebuild:
        logger.info(f"[RAG_SYSTEM_INIT] Force rebuild: Deleting existing FAISS index at '{faiss_index_actual_path}'")
        if os.path.exists(faiss_index_actual_path):
            try:
                shutil.rmtree(faiss_index_actual_path)
                logger.info("[RAG_SYSTEM_INIT] Deleted existing FAISS index")
            except Exception as e_del:
                # Non-fatal: the subsequent build overwrites whatever remains.
                logger.error(f"[RAG_SYSTEM_INIT] Could not delete existing FAISS index: {e_del}", exc_info=True)

    try:
        logger.info("[RAG_SYSTEM_INIT] Creating KnowledgeRAG instance...")
        current_rag_instance = KnowledgeRAG(
            index_storage_dir=RAG_STORAGE_PARENT_DIR,
            embedding_model_name=RAG_EMBEDDING_MODEL_NAME,
            groq_model_name_for_rag=RAG_LLM_MODEL_NAME,
            use_gpu_for_embeddings=RAG_EMBEDDING_USE_GPU,
            groq_api_key_for_rag=GROQ_API_KEY,
            temperature=RAG_LLM_TEMPERATURE,
            chunk_size=RAG_CHUNK_SIZE,
            chunk_overlap=RAG_CHUNK_OVERLAP,
            reranker_model_name=RAG_RERANKER_MODEL_NAME,
            enable_reranker=RAG_RERANKER_ENABLED,
        )

        operation_successful = False
        if RAG_LOAD_INDEX_ON_STARTUP and not force_rebuild:
            logger.info("[RAG_SYSTEM_INIT] Attempting to load index from disk")
            try:
                current_rag_instance.load_index_from_disk()
                operation_successful = True
                logger.info(f"[RAG_SYSTEM_INIT] Index loaded successfully from: {faiss_index_actual_path}")
            except FileNotFoundError:
                logger.warning("[RAG_SYSTEM_INIT] Pre-built index not found. Will build from source files")
            except Exception as e_load:
                logger.error(f"[RAG_SYSTEM_INIT] Error loading index: {e_load}. Will build from source files", exc_info=True)

        if not operation_successful:
            logger.info(f"[RAG_SYSTEM_INIT] Building new index from source data in '{source_dir_to_use}'")
            try:
                # A pre-chunked JSON file can substitute for raw source files.
                pre_chunked_path = os.path.join(RAG_STORAGE_PARENT_DIR, RAG_CHUNKED_SOURCES_FILENAME)
                sources_missing = not os.path.isdir(source_dir_to_use) or not os.listdir(source_dir_to_use)
                if not os.path.exists(pre_chunked_path) and sources_missing:
                    logger.error("[RAG_SYSTEM_INIT] Neither pre-chunked JSON nor raw source files found")
                    _write_dummy_index_files(faiss_index_actual_path)
                    current_rag_instance.processed_source_files = ["No source files found to build index."]
                    raise FileNotFoundError(f"Sources directory '{source_dir_to_use}' is empty")

                current_rag_instance.build_index_from_source_files(
                    source_folder_path=source_dir_to_use
                )
                # Persist which files went into the index for later inspection.
                os.makedirs(faiss_index_actual_path, exist_ok=True)
                with open(processed_files_metadata_path, 'w') as f:
                    json.dump(current_rag_instance.processed_source_files, f)

                operation_successful = True
                logger.info("[RAG_SYSTEM_INIT] Index built successfully from source data")
            except FileNotFoundError as e_fnf:
                logger.critical(f"[RAG_SYSTEM_INIT] FATAL: No source data found: {e_fnf}", exc_info=False)
                return None
            except ValueError as e_val:
                logger.critical(f"[RAG_SYSTEM_INIT] FATAL: No processable documents found: {e_val}", exc_info=False)
                return None
            except Exception as e_build:
                logger.critical(f"[RAG_SYSTEM_INIT] FATAL: Failed to build FAISS index: {e_build}", exc_info=True)
                return None

        if operation_successful and current_rag_instance.vector_store:
            logger.info("[RAG_SYSTEM_INIT] ========== RAG System Initialized Successfully ==========")
            return current_rag_instance
        else:
            logger.error("[RAG_SYSTEM_INIT] Index was neither loaded nor built successfully")
            return None

    except Exception as e_init_components:
        logger.critical(f"[RAG_SYSTEM_INIT] FATAL: Failed to initialize RAG system components: {e_init_components}", exc_info=True)
        return None