Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions codewiki/src/fe/chat_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
"""
FastAPI chat endpoints for interactive Q&A over generated documentation.

Endpoint:
POST /api/chat/{job_id}
Request body: {"message": "...", "session_id": "..."}
Response: text/event-stream (Server-Sent Events)

Chat history is kept **in-memory** per (job_id, session_id) pair. It
persists for the lifetime of the server process, which satisfies the
requirement that "chat history persists for the session".
"""

import re
from pathlib import Path
from typing import Dict, List

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from .chat_service import stream_chat_response

router = APIRouter()

# In-memory chat history: {(job_id, session_id): [{"role": ..., "content": ...}]}
_chat_histories: Dict[tuple, List[Dict[str, str]]] = {}

# Docs root – resolved once from the environment / config so all requests use
# the same base directory. Can be overridden for testing.
_DOCS_ROOT = Path("output") / "docs"

# Job IDs are derived from GitHub repo full-names like "owner--repo".
# Only allow alphanumeric chars, hyphens, underscores, and dots.
_JOB_ID_RE = re.compile(r"^[A-Za-z0-9._-]+$")


class ChatRequest(BaseModel):
"""Request body for POST /api/chat/{job_id}."""
message: str
session_id: str = "default"


def _validate_job_id(job_id: str) -> None:
"""Raise 400 if *job_id* contains path-traversal or unexpected characters."""
if not _JOB_ID_RE.match(job_id) or ".." in job_id:
raise HTTPException(status_code=400, detail="Invalid job ID")


def _get_history(job_id: str, session_id: str) -> List[Dict[str, str]]:
key = (job_id, session_id)
if key not in _chat_histories:
_chat_histories[key] = []
return _chat_histories[key]


def _append_history(
job_id: str,
session_id: str,
role: str,
content: str,
) -> None:
_get_history(job_id, session_id).append({"role": role, "content": content})


@router.post("/api/chat/{job_id}")
async def chat(job_id: str, body: ChatRequest):
"""
Stream an LLM response to *body.message* using RAG over the docs for
*job_id*.

The response is a ``text/event-stream`` stream. Each event carries one or
more tokens of the assistant's reply. A final ``data: [DONE]`` event
signals the end of the stream.
"""
_validate_job_id(job_id)

# Resolve and canonicalise the docs path; reject any traversal attempts.
docs_path = (_DOCS_ROOT / f"{job_id}-docs").resolve()
expected_root = _DOCS_ROOT.resolve()
if not str(docs_path).startswith(str(expected_root)):
raise HTTPException(status_code=400, detail="Invalid job ID")
if not docs_path.exists():
raise HTTPException(status_code=404, detail="Documentation not found for this job")

message = body.message.strip()
if not message:
raise HTTPException(status_code=422, detail="Message must not be empty")

# Retrieve or create history for this session
history = _get_history(job_id, body.session_id)

# Collect assistant reply so we can persist it to history afterwards
assistant_reply_parts: List[str] = []

async def _event_generator():
async for token in stream_chat_response(job_id, docs_path, message, history):
if token == "[DONE]":
# Persist turn to history now that streaming has finished
full_reply = "".join(assistant_reply_parts)
_append_history(job_id, body.session_id, "user", message)
_append_history(job_id, body.session_id, "assistant", full_reply)
yield "data: [DONE]\n\n"
else:
assistant_reply_parts.append(token)
# Escape newlines inside the data field per SSE spec
safe_token = token.replace("\n", "\\n")
yield f"data: {safe_token}\n\n"

return StreamingResponse(_event_generator(), media_type="text/event-stream")


@router.delete("/api/chat/{job_id}/history")
async def clear_history(job_id: str, session_id: str = "default"):
"""Clear chat history for *job_id* / *session_id*."""
_validate_job_id(job_id)
key = (job_id, session_id)
_chat_histories.pop(key, None)
return {"status": "cleared"}
236 changes: 236 additions & 0 deletions codewiki/src/fe/chat_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
#!/usr/bin/env python3
"""
RAG-based chat service for interactive Q&A over generated documentation.

Pipeline:
1. Load all .md files from the job's docs directory
2. Split content into paragraphs (chunks)
3. Retrieve the most relevant chunks via keyword-overlap scoring (BM25-inspired)
4. Build a context-aware prompt and stream the LLM response via SSE
"""

import re
import os
import math
from pathlib import Path
from typing import AsyncGenerator, List, Tuple, Dict


# ---------------------------------------------------------------------------
# Document loading & chunking
# ---------------------------------------------------------------------------

def _load_docs(docs_path: Path) -> List[Tuple[str, str]]:
"""
Load all markdown files from *docs_path*.

Returns a list of ``(filename, chunk_text)`` pairs where each chunk is a
non-empty paragraph from the file.
"""
chunks: List[Tuple[str, str]] = []
for md_file in sorted(docs_path.glob("*.md")):
try:
content = md_file.read_text(encoding="utf-8")
except Exception:
continue
for para in _split_into_chunks(content):
chunks.append((md_file.name, para))
return chunks


def _split_into_chunks(text: str, max_chars: int = 1500) -> List[str]:
"""
Split *text* into chunks of at most *max_chars* characters.

First splits on blank lines (paragraph boundaries), then further splits
any paragraph that exceeds *max_chars* by sentence boundaries.
"""
paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
chunks: List[str] = []
for para in paragraphs:
if len(para) <= max_chars:
chunks.append(para)
else:
# Split long paragraphs by sentence
sentences = re.split(r"(?<=[.!?])\s+", para)
current = ""
for sentence in sentences:
if len(current) + len(sentence) + 1 <= max_chars:
current = (current + " " + sentence).strip()
else:
if current:
chunks.append(current)
current = sentence
if current:
chunks.append(current)
return chunks


# ---------------------------------------------------------------------------
# Retrieval (keyword BM25-inspired scoring, no external dependencies)
# ---------------------------------------------------------------------------

def _tokenize(text: str) -> List[str]:
"""Lowercase word-tokenise *text*, removing punctuation."""
return re.findall(r"\b[a-z0-9_]+\b", text.lower())


def _build_idf(chunks: List[Tuple[str, str]]) -> Dict[str, float]:
"""Compute inverse document frequency for each token across all chunks."""
n = len(chunks)
if n == 0:
return {}
df: Dict[str, int] = {}
for _, text in chunks:
for token in set(_tokenize(text)):
df[token] = df.get(token, 0) + 1
return {token: math.log((n - freq + 0.5) / (freq + 0.5) + 1.0) for token, freq in df.items()}


def _bm25_score(
query_tokens: List[str],
doc_tokens: List[str],
idf: Dict[str, float],
k1: float = 1.5,
b: float = 0.75,
avg_dl: float = 100.0,
) -> float:
"""Compute BM25 score for a single document."""
dl = len(doc_tokens)
tf: Dict[str, int] = {}
for t in doc_tokens:
tf[t] = tf.get(t, 0) + 1

score = 0.0
for token in query_tokens:
if token not in idf:
continue
f = tf.get(token, 0)
numerator = f * (k1 + 1)
denominator = f + k1 * (1 - b + b * dl / max(avg_dl, 1))
score += idf[token] * numerator / max(denominator, 1e-9)
return score


def retrieve_relevant_chunks(
query: str,
chunks: List[Tuple[str, str]],
top_k: int = 5,
) -> List[Tuple[str, str]]:
"""
Return the *top_k* most relevant ``(filename, chunk_text)`` pairs for *query*.

Uses a lightweight BM25-inspired scoring function so no external libraries
are required.
"""
if not chunks:
return []

idf = _build_idf(chunks)
query_tokens = _tokenize(query)
avg_dl = sum(len(_tokenize(text)) for _, text in chunks) / len(chunks)

scored: List[Tuple[float, int]] = []
for i, (_, text) in enumerate(chunks):
doc_tokens = _tokenize(text)
score = _bm25_score(query_tokens, doc_tokens, idf, avg_dl=avg_dl)
scored.append((score, i))

scored.sort(key=lambda x: x[0], reverse=True)
return [chunks[i] for _, i in scored[:top_k]]


# ---------------------------------------------------------------------------
# LLM streaming
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """\
You are a helpful documentation assistant. You answer questions about a \
software repository based solely on the provided documentation excerpts.

Rules:
- Only use information from the provided excerpts.
- When citing information, mention the source file in parentheses, \
e.g. (overview.md).
- If the answer is not found in the excerpts, say so clearly.
- Format code examples with markdown code blocks.
"""


def _build_context(relevant_chunks: List[Tuple[str, str]]) -> str:
"""Format retrieved chunks as a context block for the LLM prompt."""
if not relevant_chunks:
return "(No relevant documentation excerpts found.)"
parts: List[str] = []
for filename, text in relevant_chunks:
parts.append(f"--- {filename} ---\n{text}")
return "\n\n".join(parts)


async def stream_chat_response(
job_id: str,
docs_path: Path,
message: str,
history: List[Dict[str, str]],
) -> AsyncGenerator[str, None]:
"""
Async generator that yields LLM response tokens as Server-Sent Event data.

Args:
job_id: Documentation job identifier (unused in the call itself but
useful for logging / future caching).
docs_path: Path to the directory containing the generated .md files.
message: User's natural-language question.
history: List of previous ``{"role": ..., "content": ...}`` messages
in the current session.

Yields:
Strings of the form ``data: <token>\n\n`` for each streamed token,
followed by a final ``data: [DONE]\n\n`` sentinel.
"""
from openai import AsyncOpenAI, APIError

# Load and retrieve relevant chunks
chunks = _load_docs(docs_path)
relevant = retrieve_relevant_chunks(message, chunks, top_k=6)
context = _build_context(relevant)

# Build messages list
messages: List[Dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]
# Include up to the last 10 turns of history to stay within token budget
messages.extend(history[-10:])
messages.append(
{
"role": "user",
"content": (
f"Documentation excerpts:\n\n{context}\n\n"
f"Question: {message}"
),
}
)

# Read LLM config from environment (same variables used by the web app and
# the CLI; falls back to the values defined in codewiki.src.config).
llm_base_url = os.getenv("LLM_BASE_URL", "http://localhost:4000/")
llm_api_key = os.getenv("LLM_API_KEY", "sk-1234")
main_model = os.getenv("MAIN_MODEL", "claude-sonnet-4")

client = AsyncOpenAI(base_url=llm_base_url, api_key=llm_api_key)

try:
stream = await client.chat.completions.create(
model=main_model,
messages=messages, # type: ignore[arg-type]
stream=True,
temperature=0.3,
max_tokens=2048,
)
async for chunk in stream:
delta = chunk.choices[0].delta if chunk.choices else None
if delta and delta.content:
yield delta.content
except APIError:
# Avoid leaking internal error details (e.g. API keys in URLs) to the client.
yield "\n\n⚠️ The LLM service returned an error. Please try again later."
finally:
yield "[DONE]"
Loading