import asyncio
import json
import uuid

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from context_loader import load_context
from llm_client import complete
from session_logger import log_turn
from session_store import load as load_session, save as save_session, list_all
from config import settings

router = APIRouter()

# The two backend names this module accepts; whichever is not primary is
# reported as the fallback.
_BACKENDS = ("claude", "gemini")

# Seconds to wait on the LLM task between keepalive events.
_KEEPALIVE_SECONDS = 3


class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    message: str
    # Falsy/absent session_id makes the handler generate a fresh 8-char id.
    session_id: str | None = None
    # Falsy/absent tier falls back to settings.default_tier.
    tier: int | None = None
    model: str | None = None  # "claude" or "gemini" to override; None = use primary_backend


class BackendRequest(BaseModel):
    """Request body for ``POST /backend``."""

    primary: str  # "claude" or "gemini"


class NoteRequest(BaseModel):
    """Request body for ``POST /note``."""

    session_id: str
    note: str


def _sse(payload: dict) -> str:
    """Frame *payload* as one JSON-encoded SSE ``data:`` event."""
    return f"data: {json.dumps(payload)}\n\n"


def _fallback_for(primary: str) -> str:
    """Return the backend reported as fallback when *primary* is primary."""
    return "gemini" if primary == "claude" else "claude"


async def _stream_chat(req: ChatRequest):
    """
    SSE generator: sends keepalive events every ``_KEEPALIVE_SECONDS`` (3s)
    while the LLM works, then sends the final response. Keeps the browser
    connection alive regardless of how long the backend takes.

    Any failure -- including loading the context or the session before the
    LLM call starts -- is reported in-band as an ``error`` event rather than
    by raising out of the generator.

    Event types:
      data: {"type": "keepalive"}
      data: {"type": "response", "response": "...", "session_id": "...",
             "backend": "...", "fallback_used": bool}
      data: {"type": "error", "message": "..."}
    """
    session_id = req.session_id or str(uuid.uuid4())[:8]
    # NOTE(review): `or` treats an explicit tier of 0 as "not provided" and
    # substitutes the default -- confirm 0 is never a meaningful tier.
    tier = req.tier or settings.default_tier

    # Setup can fail (missing context, unreadable session). Without this guard
    # the exception would escape the generator and the client would never see
    # the documented error event.
    try:
        system_prompt = load_context(tier)
        history = load_session(session_id)
        history.append({"role": "user", "content": req.message})
    except Exception as exc:
        yield _sse({"type": "error", "message": str(exc)})
        return

    task = asyncio.create_task(complete(
        system_prompt=system_prompt,
        messages=history,
        model=req.model,
    ))

    try:
        # Ping the browser periodically so it doesn't drop the connection.
        while not task.done():
            yield 'data: {"type":"keepalive"}\n\n'
            try:
                # shield() keeps the timeout from cancelling the LLM task
                # itself; only this wait is abandoned.
                await asyncio.wait_for(
                    asyncio.shield(task), timeout=_KEEPALIVE_SECONDS
                )
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The task failed; task.result() below re-raises the same
                # exception and turns it into an error event.
                break

        try:
            response_text, actual_backend = task.result()
            history.append({"role": "assistant", "content": response_text})
            save_session(session_id, history)
            log_turn(session_id, req.message, response_text)
            requested = req.model or settings.primary_backend
            payload = {
                "type": "response",
                "response": response_text,
                "session_id": session_id,
                "backend": actual_backend,
                "fallback_used": actual_backend != requested,
            }
            yield _sse(payload)
        except Exception as e:
            yield _sse({"type": "error", "message": str(e)})
    finally:
        # Ensure the LLM task is cancelled if the generator is torn down
        # (e.g. client disconnect or server shutdown).
        # NOTE(review): the original comment says this propagates
        # CancelledError into _gemini(), which kills the process group; that
        # code lives outside this file -- confirm in llm_client.
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                # Best-effort teardown: the task's outcome no longer matters.
                pass


@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Run one chat turn and stream the result as server-sent events."""
    return StreamingResponse(
        _stream_chat(req),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@router.get("/backend")
async def get_backend() -> dict:
    """Report the current primary backend and its fallback."""
    return {
        "primary": settings.primary_backend,
        "fallback": _fallback_for(settings.primary_backend),
    }


@router.post("/backend")
async def set_backend(req: BackendRequest) -> dict:
    """Set the primary backend; responds 400 for an unknown backend name."""
    if req.primary not in _BACKENDS:
        raise HTTPException(status_code=400, detail="primary must be 'claude' or 'gemini'")
    settings.primary_backend = req.primary
    return {
        "primary": settings.primary_backend,
        "fallback": _fallback_for(req.primary),
    }


@router.get("/history/{session_id}")
async def get_history(session_id: str) -> dict:
    """Return the stored messages for *session_id*."""
    return {"session_id": session_id, "messages": load_session(session_id)}


@router.get("/sessions")
async def list_sessions() -> dict:
    """Return whatever the session store lists as known sessions."""
    return {"sessions": list_all()}


@router.post("/note")
async def add_note(req: NoteRequest) -> dict:
    """Inject a public note into session history so the LLM sees it next turn."""
    history = load_session(req.session_id)
    # Notes are stored as user-role messages tagged with a "[NOTE]" prefix.
    history.append({"role": "user", "content": f"[NOTE] {req.note}"})
    save_session(req.session_id, history)
    return {"ok": True, "session_id": req.session_id}