diff --git a/codewiki/src/fe/chat_routes.py b/codewiki/src/fe/chat_routes.py new file mode 100644 index 00000000..2c2db3cd --- /dev/null +++ b/codewiki/src/fe/chat_routes.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +FastAPI chat endpoints for interactive Q&A over generated documentation. + +Endpoint: + POST /api/chat/{job_id} + Request body: {"message": "...", "session_id": "..."} + Response: text/event-stream (Server-Sent Events) + +Chat history is kept **in-memory** per (job_id, session_id) pair. It +persists for the lifetime of the server process, which satisfies the +requirement that "chat history persists for the session". +""" + +import re +from pathlib import Path +from typing import Dict, List + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from .chat_service import stream_chat_response + +router = APIRouter() + +# In-memory chat history: {(job_id, session_id): [{"role": ..., "content": ...}]} +_chat_histories: Dict[tuple, List[Dict[str, str]]] = {} + +# Docs root – resolved once from the environment / config so all requests use +# the same base directory. Can be overridden for testing. +_DOCS_ROOT = Path("output") / "docs" + +# Job IDs are derived from GitHub repo full-names like "owner--repo". +# Only allow alphanumeric chars, hyphens, underscores, and dots. +_JOB_ID_RE = re.compile(r"^[A-Za-z0-9._-]+$") + + +class ChatRequest(BaseModel): + """Request body for POST /api/chat/{job_id}.""" + message: str + session_id: str = "default" + + +def _validate_job_id(job_id: str) -> None: + """Raise 400 if *job_id* contains path-traversal or unexpected characters.""" + if not _JOB_ID_RE.match(job_id) or ".." in job_id: + raise HTTPException(status_code=400, detail="Invalid job ID") + + +def _get_history(job_id: str, session_id: str) -> List[Dict[str, str]]: + key = (job_id, session_id) + if key not in _chat_histories: + _chat_histories[key] = [] + return _chat_histories[key] + + +def _append_history( + job_id: str, + session_id: str, + role: str, + content: str, +) -> None: + _get_history(job_id, session_id).append({"role": role, "content": content}) + + +@router.post("/api/chat/{job_id}") +async def chat(job_id: str, body: ChatRequest): + """ + Stream an LLM response to *body.message* using RAG over the docs for + *job_id*. + + The response is a ``text/event-stream`` stream. Each event carries one or + more tokens of the assistant's reply. A final ``data: [DONE]`` event + signals the end of the stream. + """ + _validate_job_id(job_id) + + # Resolve and canonicalise the docs path; reject any traversal attempts. + docs_path = (_DOCS_ROOT / f"{job_id}-docs").resolve() + expected_root = _DOCS_ROOT.resolve() + if not str(docs_path).startswith(str(expected_root)): + raise HTTPException(status_code=400, detail="Invalid job ID") + if not docs_path.exists(): + raise HTTPException(status_code=404, detail="Documentation not found for this job") + + message = body.message.strip() + if not message: + raise HTTPException(status_code=422, detail="Message must not be empty") + + # Retrieve or create history for this session + history = _get_history(job_id, body.session_id) + + # Collect assistant reply so we can persist it to history afterwards + assistant_reply_parts: List[str] = [] + + async def _event_generator(): + async for token in stream_chat_response(job_id, docs_path, message, history): + if token == "[DONE]": + # Persist turn to history now that streaming has finished + full_reply = "".join(assistant_reply_parts) + _append_history(job_id, body.session_id, "user", message) + _append_history(job_id, body.session_id, "assistant", full_reply) + yield "data: [DONE]\n\n" + else: + assistant_reply_parts.append(token) + # Escape newlines inside the data field per SSE spec + safe_token = token.replace("\n", "\\n") + yield f"data: {safe_token}\n\n" + + return StreamingResponse(_event_generator(), media_type="text/event-stream") + + +@router.delete("/api/chat/{job_id}/history") +async def clear_history(job_id: str, session_id: str = "default"): + """Clear chat history for *job_id* / *session_id*.""" + _validate_job_id(job_id) + key = (job_id, session_id) + _chat_histories.pop(key, None) + return {"status": "cleared"} diff --git a/codewiki/src/fe/chat_service.py b/codewiki/src/fe/chat_service.py new file mode 100644 index 00000000..4a37d2c8 --- /dev/null +++ b/codewiki/src/fe/chat_service.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +RAG-based chat service for interactive Q&A over generated documentation. + +Pipeline: +1. Load all .md files from the job's docs directory +2. Split content into paragraphs (chunks) +3. Retrieve the most relevant chunks via keyword-overlap scoring (BM25-inspired) +4. Build a context-aware prompt and stream the LLM response via SSE +""" + +import re +import os +import math +from pathlib import Path +from typing import AsyncGenerator, List, Tuple, Dict + + +# --------------------------------------------------------------------------- +# Document loading & chunking +# --------------------------------------------------------------------------- + +def _load_docs(docs_path: Path) -> List[Tuple[str, str]]: + """ + Load all markdown files from *docs_path*. + + Returns a list of ``(filename, chunk_text)`` pairs where each chunk is a + non-empty paragraph from the file. + """ + chunks: List[Tuple[str, str]] = [] + for md_file in sorted(docs_path.glob("*.md")): + try: + content = md_file.read_text(encoding="utf-8") + except Exception: + continue + for para in _split_into_chunks(content): + chunks.append((md_file.name, para)) + return chunks + + +def _split_into_chunks(text: str, max_chars: int = 1500) -> List[str]: + """ + Split *text* into chunks of at most *max_chars* characters. + + First splits on blank lines (paragraph boundaries), then further splits + any paragraph that exceeds *max_chars* by sentence boundaries. + """ + paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()] + chunks: List[str] = [] + for para in paragraphs: + if len(para) <= max_chars: + chunks.append(para) + else: + # Split long paragraphs by sentence + sentences = re.split(r"(?<=[.!?])\s+", para) + current = "" + for sentence in sentences: + if len(current) + len(sentence) + 1 <= max_chars: + current = (current + " " + sentence).strip() + else: + if current: + chunks.append(current) + current = sentence + if current: + chunks.append(current) + return chunks + + +# --------------------------------------------------------------------------- +# Retrieval (keyword BM25-inspired scoring, no external dependencies) +# --------------------------------------------------------------------------- + +def _tokenize(text: str) -> List[str]: + """Lowercase word-tokenise *text*, removing punctuation.""" + return re.findall(r"\b[a-z0-9_]+\b", text.lower()) + + +def _build_idf(chunks: List[Tuple[str, str]]) -> Dict[str, float]: + """Compute inverse document frequency for each token across all chunks.""" + n = len(chunks) + if n == 0: + return {} + df: Dict[str, int] = {} + for _, text in chunks: + for token in set(_tokenize(text)): + df[token] = df.get(token, 0) + 1 + return {token: math.log((n - freq + 0.5) / (freq + 0.5) + 1.0) for token, freq in df.items()} + + +def _bm25_score( + query_tokens: List[str], + doc_tokens: List[str], + idf: Dict[str, float], + k1: float = 1.5, + b: float = 0.75, + avg_dl: float = 100.0, +) -> float: + """Compute BM25 score for a single document.""" + dl = len(doc_tokens) + tf: Dict[str, int] = {} + for t in doc_tokens: + tf[t] = tf.get(t, 0) + 1 + + score = 0.0 + for token in query_tokens: + if token not in idf: + continue + f = tf.get(token, 0) + numerator = f * (k1 + 1) + denominator = f + k1 * (1 - b + b * dl / max(avg_dl, 1)) + score += idf[token] * numerator / max(denominator, 1e-9) + return score + + +def retrieve_relevant_chunks( + query: str, + chunks: List[Tuple[str, str]], + top_k: int = 5, +) -> List[Tuple[str, str]]: + """ + Return the *top_k* most relevant ``(filename, chunk_text)`` pairs for *query*. + + Uses a lightweight BM25-inspired scoring function so no external libraries + are required. + """ + if not chunks: + return [] + + idf = _build_idf(chunks) + query_tokens = _tokenize(query) + avg_dl = sum(len(_tokenize(text)) for _, text in chunks) / len(chunks) + + scored: List[Tuple[float, int]] = [] + for i, (_, text) in enumerate(chunks): + doc_tokens = _tokenize(text) + score = _bm25_score(query_tokens, doc_tokens, idf, avg_dl=avg_dl) + scored.append((score, i)) + + scored.sort(key=lambda x: x[0], reverse=True) + return [chunks[i] for _, i in scored[:top_k]] + + +# --------------------------------------------------------------------------- +# LLM streaming +# --------------------------------------------------------------------------- + +SYSTEM_PROMPT = """\ +You are a helpful documentation assistant. You answer questions about a \ +software repository based solely on the provided documentation excerpts. + +Rules: +- Only use information from the provided excerpts. +- When citing information, mention the source file in parentheses, \ + e.g. (overview.md). +- If the answer is not found in the excerpts, say so clearly. +- Format code examples with markdown code blocks. +""" + + +def _build_context(relevant_chunks: List[Tuple[str, str]]) -> str: + """Format retrieved chunks as a context block for the LLM prompt.""" + if not relevant_chunks: + return "(No relevant documentation excerpts found.)" + parts: List[str] = [] + for filename, text in relevant_chunks: + parts.append(f"--- {filename} ---\n{text}") + return "\n\n".join(parts) + + +async def stream_chat_response( + job_id: str, + docs_path: Path, + message: str, + history: List[Dict[str, str]], +) -> AsyncGenerator[str, None]: + """ + Async generator that yields LLM response tokens as Server-Sent Event data. + + Args: + job_id: Documentation job identifier (unused in the call itself but + useful for logging / future caching). + docs_path: Path to the directory containing the generated .md files. + message: User's natural-language question. + history: List of previous ``{"role": ..., "content": ...}`` messages + in the current session. + + Yields: + Strings of the form ``data: \n\n`` for each streamed token, + followed by a final ``data: [DONE]\n\n`` sentinel. + """ + from openai import AsyncOpenAI, APIError + + # Load and retrieve relevant chunks + chunks = _load_docs(docs_path) + relevant = retrieve_relevant_chunks(message, chunks, top_k=6) + context = _build_context(relevant) + + # Build messages list + messages: List[Dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}] + # Include up to the last 10 turns of history to stay within token budget + messages.extend(history[-10:]) + messages.append( + { + "role": "user", + "content": ( + f"Documentation excerpts:\n\n{context}\n\n" + f"Question: {message}" + ), + } + ) + + # Read LLM config from environment (same variables used by the web app and + # the CLI; falls back to the values defined in codewiki.src.config). + llm_base_url = os.getenv("LLM_BASE_URL", "http://localhost:4000/") + llm_api_key = os.getenv("LLM_API_KEY", "sk-1234") + main_model = os.getenv("MAIN_MODEL", "claude-sonnet-4") + + client = AsyncOpenAI(base_url=llm_base_url, api_key=llm_api_key) + + try: + stream = await client.chat.completions.create( + model=main_model, + messages=messages, # type: ignore[arg-type] + stream=True, + temperature=0.3, + max_tokens=2048, + ) + async for chunk in stream: + delta = chunk.choices[0].delta if chunk.choices else None + if delta and delta.content: + yield delta.content + except APIError: + # Avoid leaking internal error details (e.g. API keys in URLs) to the client. + yield "\n\n⚠️ The LLM service returned an error. Please try again later." + finally: + yield "[DONE]" diff --git a/codewiki/src/fe/templates.py b/codewiki/src/fe/templates.py index 763562c7..ad11646f 100644 --- a/codewiki/src/fe/templates.py +++ b/codewiki/src/fe/templates.py @@ -675,6 +675,383 @@ mermaid.init(undefined, document.querySelectorAll('.mermaid')); }); + + + + +
+ + +
+ + """ \ No newline at end of file diff --git a/codewiki/src/fe/web_app.py b/codewiki/src/fe/web_app.py index 6f5d846b..84b92e10 100644 --- a/codewiki/src/fe/web_app.py +++ b/codewiki/src/fe/web_app.py @@ -18,6 +18,7 @@ from .background_worker import BackgroundWorker from .routes import WebRoutes from .config import WebAppConfig +from .chat_routes import router as chat_router # Initialize FastAPI app @@ -37,6 +38,9 @@ ) web_routes = WebRoutes(background_worker=background_worker, cache_manager=cache_manager) +# Register chat routes +app.include_router(chat_router) + # Register routes @app.get("/", response_class=HTMLResponse)