-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·125 lines (99 loc) · 3.81 KB
/
main.py
File metadata and controls
executable file
·125 lines (99 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Main entry point for the RAG system.
"""
# Server setup libraries
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from pydantic import BaseModel
# Configuration and utility libraries
from src.embeddings.models import StellaEmbeddings
from src.generation.llm import generate_response
from src.services.rag_pipeline import create_prompt, initialize_database, retrieve_contexts
from src.utils.helpers import parse_arguments
# Authentication
from config import HUGGING_FACE_ACCESS_TOKEN, DEBUG
from huggingface_hub import login
import logging
# Define Colored Formatter for Logs
class ColoredFormatter(logging.Formatter):
    """Logging formatter that wraps each record's text in an ANSI color per level."""

    # Map each stdlib log level to its ANSI escape prefix.
    COLORS = {
        logging.DEBUG: "\033[94m",     # Blue
        logging.INFO: "\033[92m",      # Green
        logging.WARNING: "\033[93m",   # Yellow
        logging.ERROR: "\033[91m",     # Red
        logging.CRITICAL: "\033[95m",  # Magenta
    }
    RESET = "\033[0m"

    def format(self, record):
        """Render the record normally, then bracket it with the level's color code."""
        # Unknown levels fall back to RESET, i.e. an uncolored prefix.
        prefix = self.COLORS.get(record.levelno, self.RESET)
        return f"{prefix}{super().format(record)}{self.RESET}"
# Configure Root Logger
# DEBUG flag from config switches the whole app between DEBUG and INFO verbosity.
log_level = logging.DEBUG if DEBUG else logging.INFO
handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
# Install the colored handler as the only root handler.
logging.basicConfig(level=log_level, handlers=[handler])
# Suppress noisy external loggers
for logger_name in ["filelock", "urllib3", "huggingface_hub", "sentence_transformers", "chromadb", "httpx", "httpcore", "fsspec", "tzlocal"]:
    logging.getLogger(logger_name).setLevel(logging.WARNING)
# Module-level logger used by the handlers below.
logger = logging.getLogger(__name__)
# This holds the model in RAM so we don't reload it
# Populated once by the lifespan hook at server startup; None until then.
global_embedding_model = None
@asynccontextmanager
async def lifespan(_: FastAPI):
    """
    FastAPI lifespan hook: executes once when the server starts.

    Loads the Stella embedding model into the module-level
    ``global_embedding_model``, logs in to Hugging Face, and initializes
    the vector database, then yields for the server's lifetime.

    Raises:
        RuntimeError: if any startup step fails (chained to the cause).
    """
    global global_embedding_model
    try:
        global_embedding_model = StellaEmbeddings()
        login(token=HUGGING_FACE_ACCESS_TOKEN)
        initialize_database(global_embedding_model)
    except Exception as e:
        # Chain the original exception so startup failures stay debuggable.
        raise RuntimeError(f"Failed to load model during startup: {e}") from e
    yield  # Server runs here
# Attach the lifespan hook so the embedding model loads before traffic is served.
app = FastAPI(lifespan=lifespan)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec — confirm whether credentials are
# actually needed, or pin concrete origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Query(BaseModel):
    """Incoming WebSocket payload: a single user prompt for the RAG pipeline."""
    # Raw user question; forwarded to retrieval and generation.
    prompt: str
@app.websocket("/query")
async def handle_query_websocket(websocket: WebSocket):
    """
    WebSocket endpoint for RAG queries.

    Accepts JSON messages shaped like ``{"prompt": "..."}`` and answers each
    with ``{"response": ..., "status": 200}``. If the embedding model has not
    finished loading, sends a 503 payload and closes immediately.
    """
    await websocket.accept()
    if global_embedding_model is None:
        # Startup (lifespan) has not populated the model yet.
        await websocket.send_json({"detail": "Server is still warming up. Try again in 5s.", "status": 503})
        await websocket.close()
        return
    try:
        while True:
            data = await websocket.receive_json()
            query = Query(**data)
            # Pass the PRE-LOADED global model
            response, _ = orchestratePipeline(query.prompt, global_embedding_model)
            await websocket.send_json(
                {"response": response, "status": 200}
            )
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        # logger.exception preserves the traceback (logger.error discarded it).
        logger.exception("An error occurred: %s", e)
        # The socket may already be dead (e.g. the error happened mid-send);
        # a failing send/close here must not mask the original error.
        try:
            await websocket.send_json({"detail": str(e), "status": 500})
            await websocket.close()
        except Exception:
            pass
def orchestratePipeline(query_text, embedding_model):
    """
    Run the full RAG pipeline for one query.

    Retrieves matching contexts using the pre-loaded embedding model (passed
    in so it is not reloaded per request), builds a prompt from them, and
    generates the final answer.

    Returns:
        A ``(response, contexts)`` tuple.
    """
    # Retrieval reuses the model loaded once at startup.
    contexts = retrieve_contexts(query_text, embedding_model)
    if not contexts:
        logger.warning(f"Unable to find matching results for '{query_text}'")
    # Assemble the LLM prompt from whatever was retrieved and generate.
    prompt = create_prompt(query_text, contexts)
    return generate_response(prompt), contexts