import asyncio
import json
import logging

import jwt
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from context_loader import load_context
from llm_client import complete
from session_logger import log_turn
from session_store import (
    load as load_session,
    save as save_session,
    list_all,
    generate_session_id,
    delete as delete_session,
    rename as rename_session,
)
from config import settings
from persona import set_context, validate as validate_persona
from auth_utils import COOKIE_NAME, decode_token
import user_settings
import event_bus

logger = logging.getLogger(__name__)

router = APIRouter()

# Pre-rendered SSE keepalive frame (kept byte-for-byte; no JSON re-encoding).
_KEEPALIVE_EVENT = 'data: {"type":"keepalive"}\n\n'

# Seconds between keepalive frames on the /chat and /events streams.
_CHAT_KEEPALIVE_SECONDS = 3
_EVENTS_KEEPALIVE_SECONDS = 20


def _sse_event(payload: dict) -> str:
    """Render *payload* as a single server-sent-events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    message: str
    session_id: str | None = None
    tier: int | None = None
    model: str | None = None  # "claude" or "gemini" to override; None = use primary_backend
    include_long: bool = True
    include_mid: bool = True
    include_short: bool = True
    off_record: bool = False  # skip session log (in-memory context preserved)
    user: str = "scott"
    persona: str = "inara"


class BackendRequest(BaseModel):
    """Request body for ``POST /backend``."""

    primary: str  # "claude", "gemini", or "local"


class NoteRequest(BaseModel):
    """Request body for ``POST /note``."""

    session_id: str
    note: str


class HistoryUpdate(BaseModel):
    """Request body for ``PUT /history/{session_id}``."""

    messages: list[dict]


async def _stream_chat(req: ChatRequest):
    """
    SSE generator: sends keepalive events every 3s while the LLM works,
    then sends the final response. Keeps the browser connection alive
    regardless of how long the backend takes.

    Event types:
      data: {"type": "keepalive"}
      data: {"type": "response", "response": "...", "session_id": "...",
             "backend": "...", "fallback_used": bool}
      data: {"type": "error", "message": "..."}

    Any failure (bad user/persona, context or session loading, the LLM call,
    or persisting the turn) is reported to the client as an ``error`` event
    rather than by aborting the stream.
    """
    try:
        user, persona = validate_persona(req.user, req.persona)
        set_context(user, persona)
    except ValueError as e:
        yield _sse_event({"type": "error", "message": str(e)})
        return

    session_id = req.session_id or generate_session_id()
    # NOTE: `or` means a falsy tier (None or 0) falls back to the default.
    tier = req.tier or settings.default_tier

    # The response has already started by the time this generator runs, so an
    # exception escaping here would just truncate the stream. Report it using
    # the documented error event instead, and log it so the traceback is kept.
    try:
        system_prompt = load_context(
            tier,
            include_long=req.include_long,
            include_mid=req.include_mid,
            include_short=req.include_short,
        )
        history = load_session(session_id)
    except Exception as e:
        logger.exception("chat setup failed for session %s", session_id)
        yield _sse_event({"type": "error", "message": str(e)})
        return

    history.append({"role": "user", "content": req.message})

    task = asyncio.create_task(complete(
        system_prompt=system_prompt,
        messages=history,
        model=req.model,
    ))

    try:
        # Ping the browser every 3s so it doesn't drop the connection
        while not task.done():
            yield _KEEPALIVE_EVENT
            try:
                # shield() keeps the timeout from cancelling the LLM task itself.
                await asyncio.wait_for(
                    asyncio.shield(task), timeout=_CHAT_KEEPALIVE_SECONDS
                )
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The task failed; task.result() below re-raises and reports it.
                break

        try:
            response_text, actual_backend = task.result()
            history.append({"role": "assistant", "content": response_text})
            save_session(session_id, history)
            if not req.off_record:
                log_turn(session_id, req.message, response_text)
            requested = req.model or settings.primary_backend
            payload = {
                "type": "response",
                "response": response_text,
                "session_id": session_id,
                "backend": actual_backend,
                "fallback_used": actual_backend != requested,
            }
            yield _sse_event(payload)
        except Exception as e:
            yield _sse_event({"type": "error", "message": str(e)})
    finally:
        # Ensure the LLM task is cancelled if the generator is torn down
        # (e.g. client disconnect or server shutdown). This propagates
        # CancelledError into _gemini() which kills the process group.
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass


@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Run one chat turn and stream keepalives plus the result as SSE."""
    return StreamingResponse(
        _stream_chat(req),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


# Accepted values for the primary backend, and the fallback reported for each.
_BACKEND_CYCLE = ("claude", "gemini", "local")
_BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}


def _local_model_info(request: Request) -> dict | None:
    """Return active local model {label, model_name} for the session user, or None."""
    try:
        token = request.cookies.get(COOKIE_NAME)
        username = decode_token(token) if token else None
        if not username:
            return None
        cfg = user_settings.get_active_local_model(username)
        if cfg:
            return {"label": cfg["label"], "model_name": cfg["model_name"]}
    except (jwt.InvalidTokenError, Exception):
        # Best-effort lookup: any failure (bad token, settings error, missing
        # keys) simply means "no local model info" for the caller.
        pass
    return None


@router.get("/backend")
async def get_backend(request: Request) -> dict:
    """Report the current primary backend, its fallback, and local model info."""
    p = settings.primary_backend
    return {
        "primary": p,
        "fallback": _BACKEND_FALLBACK.get(p, "claude"),
        "local_model": _local_model_info(request),
    }


@router.post("/backend")
async def set_backend(req: BackendRequest, request: Request) -> dict:
    """Set the primary backend on ``settings``; 400 if the name is not recognised."""
    if req.primary not in _BACKEND_CYCLE:
        raise HTTPException(
            status_code=400,
            detail="primary must be 'claude', 'gemini', or 'local'",
        )
    settings.primary_backend = req.primary
    return {
        "primary": req.primary,
        "fallback": _BACKEND_FALLBACK[req.primary],
        "local_model": _local_model_info(request),
    }


def _set_ctx(user: str, persona: str) -> None:
    """Validate and set persona context from query params.

    Raises HTTPException on bad input.
    """
    try:
        u, p = validate_persona(user, persona)
        set_context(u, p)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))


@router.get("/history/{session_id}")
async def get_history(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return the stored message list for *session_id*."""
    _set_ctx(user, persona)
    return {"session_id": session_id, "messages": load_session(session_id)}


@router.get("/sessions")
async def list_sessions(
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return everything ``list_all()`` reports for the given user/persona context."""
    _set_ctx(user, persona)
    return {"sessions": list_all()}


class SessionRename(BaseModel):
    """Request body for ``PATCH /sessions/{session_id}``."""

    name: str


@router.patch("/sessions/{session_id}")
async def rename_session_endpoint(
    session_id: str,
    req: SessionRename,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Rename a session to the whitespace-stripped name; 404 if it does not exist."""
    _set_ctx(user, persona)
    new_name = req.name.strip()
    found = rename_session(session_id, new_name)
    if not found:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"ok": True, "session_id": session_id, "name": new_name}


@router.delete("/sessions/{session_id}")
async def delete_session_endpoint(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Delete a session; 404 if it does not exist."""
    _set_ctx(user, persona)
    found = delete_session(session_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"ok": True, "session_id": session_id}


@router.put("/history/{session_id}")
async def replace_history(
    session_id: str,
    req: HistoryUpdate,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Replace the full message list for a session (used by edit/delete UI)."""
    _set_ctx(user, persona)
    save_session(session_id, req.messages)
    return {"ok": True, "session_id": session_id}


@router.get("/events")
async def sse_events() -> StreamingResponse:
    """Server-sent events stream — pushes real-time Talk activity to the browser."""

    async def stream():
        q = event_bus.subscribe()
        try:
            while True:
                try:
                    event = await asyncio.wait_for(
                        q.get(), timeout=_EVENTS_KEEPALIVE_SECONDS
                    )
                    yield _sse_event(event)
                except asyncio.TimeoutError:
                    # Nothing published within the window: keep the connection warm.
                    yield _KEEPALIVE_EVENT
        except (GeneratorExit, asyncio.CancelledError):
            # Stream torn down (generator closed or task cancelled): end quietly.
            pass
        finally:
            # Always drop the subscription, however the stream ends.
            event_bus.unsubscribe(q)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@router.post("/note")
async def add_note(
    req: NoteRequest,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Inject a public note into session history so the LLM sees it next turn."""
    _set_ctx(user, persona)
    history = load_session(req.session_id)
    # Stored as a user-role message prefixed with "[NOTE] ".
    history.append({"role": "user", "content": f"[NOTE] {req.note}"})
    save_session(req.session_id, history)
    return {"ok": True, "session_id": req.session_id}