"""Chat, session-history and backend-selection HTTP routes.

Exposes a FastAPI ``router`` with:

* ``POST /chat``      - SSE stream carrying keepalives and the final LLM reply
* ``GET/POST /backend`` - inspect / change the legacy primary backend
* session endpoints   - history, list, rename, delete, backfill names, notes
* ``GET /events``     - SSE stream of events published on ``event_bus``
"""

import asyncio
import json
import platform

import jwt
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from context_loader import load_context
from llm_client import complete
from session_logger import log_turn
from session_store import (
    load as load_session,
    save as save_session,
    list_all,
    generate_session_id,
    delete as delete_session,
    rename as rename_session,
    get_name as get_session_name,
)
from config import settings
from persona import set_context, validate as validate_persona
from auth_utils import COOKIE_NAME, decode_token
import model_registry
import event_bus
from model_registry import get_role_config

router = APIRouter()


def _backend_label(backend: str, username: str, role: str = "chat") -> str:
    """Human-readable label for the model that handled a request (legacy path).

    ``"claude"`` and ``"gemini"`` map to fixed labels.  ``"local"`` is looked
    up in the model registry for ``username``/``role`` and falls back to
    ``"Local"``.  Any other backend name is returned title-cased.
    """
    if backend == "claude":
        return "Claude"
    if backend == "gemini":
        return "Gemini"
    if backend == "local":
        cfg = model_registry.get_best_local_model(username, role)
        if cfg:
            return cfg.get("label") or cfg.get("model_name") or "Local"
        return "Local"
    return backend.title()


def _role_model_label(username: str, role: str, actual_backend: str) -> str:
    """Return the model label for a role, falling back to the generic backend label."""
    cfg = model_registry.get_model_for_role(username, role)
    if cfg:
        return (
            cfg.get("label")
            or cfg.get("model_name")
            or _backend_label(actual_backend, username, role)
        )
    return _backend_label(actual_backend, username, role)


class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    message: str
    session_id: str | None = None
    tier: int | None = None
    model: str | None = None  # legacy backend override ("claude"|"gemini"|"local")
    slot: str | None = None  # Phase 3: explicit slot ("primary"|"backup_1"|"backup_2")
    chat_role: str = "chat"  # active role: "chat"|"coder"|"research"|"distill" etc.
    include_long: bool = True
    include_mid: bool = True
    include_short: bool = True
    off_record: bool = False  # skip session log (in-memory context preserved)
    user: str = "scott"
    persona: str = "inara"


class BackendRequest(BaseModel):
    """Request body for ``POST /backend``."""

    primary: str  # "claude", "gemini", or "local"


class NoteRequest(BaseModel):
    """Request body for ``POST /note``."""

    session_id: str
    note: str


class HistoryUpdate(BaseModel):
    """Request body for ``PUT /history/{session_id}``."""

    messages: list[dict]


async def _stream_chat(req: ChatRequest):
    """
    SSE generator: sends keepalive events every 3s while the LLM works,
    then sends the final response. Keeps the browser connection alive
    regardless of how long the backend takes.

    Event types:
      data: {"type": "keepalive"}
      data: {"type": "response", "response": "...", "session_id": "...",
             "backend": "...", "fallback_used": bool}
      data: {"type": "error", "message": "..."}
    """
    try:
        user, persona = validate_persona(req.user, req.persona)
        set_context(user, persona)
    except ValueError as e:
        # Bad user/persona: report it on the stream and stop.
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        return

    session_id = req.session_id or generate_session_id()
    tier = req.tier or settings.default_tier
    role_cfg = get_role_config(user, req.chat_role)
    system_prompt = load_context(
        tier,
        include_long=req.include_long,
        include_mid=req.include_mid,
        include_short=req.include_short,
        inject_datetime=role_cfg.get("inject_datetime", True),
        inject_mode=role_cfg.get("inject_mode", True),
        mode="otr" if req.off_record else "chat",
    )
    history = load_session(session_id)
    history.append({"role": "user", "content": req.message, "off_record": req.off_record})

    # Run the completion as a separate task so keepalives can be emitted
    # while it is still in flight.
    task = asyncio.create_task(complete(
        system_prompt=system_prompt,
        messages=history,
        model=req.model,
        role=req.chat_role,
        slot=req.slot,
    ))

    try:
        # Ping the browser every 3s so it doesn't drop the connection
        while not task.done():
            yield 'data: {"type":"keepalive"}\n\n'
            try:
                # shield() keeps the wait_for timeout from cancelling the task.
                await asyncio.wait_for(asyncio.shield(task), timeout=3)
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The task itself failed; task.result() below re-raises the
                # error so it is reported as an "error" event.
                break

        try:
            response_text, actual_backend = task.result()
            if req.slot:
                slot_cfg = model_registry.get_model_for_slot(user, req.chat_role, req.slot)
                backend_label = (
                    (slot_cfg or {}).get("label")
                    or _role_model_label(user, req.chat_role, actual_backend)
                )
            else:
                backend_label = _role_model_label(user, req.chat_role, actual_backend)
            host = platform.node()
            history.append({
                "role": "assistant",
                "content": response_text,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
                "off_record": req.off_record,
            })
            save_session(session_id, history)
            if not req.off_record:
                log_turn(session_id, req.message, response_text, backend_label, host)
            # fallback_used only makes sense for explicit backend selections.
            # In auto mode (req.model is None), just report what responded.
            fallback_used = bool(req.model and actual_backend != req.model)
            payload = {
                "type": "response",
                "response": response_text,
                "session_id": session_id,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
                "fallback_used": fallback_used,
            }
            yield f"data: {json.dumps(payload)}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
    finally:
        # Ensure the LLM task is cancelled if the generator is torn down
        # (e.g. client disconnect or server shutdown). This propagates
        # CancelledError into _gemini() which kills the process group.
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass


@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Stream the reply for one chat turn as server-sent events."""
    return StreamingResponse(
        _stream_chat(req),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


_BACKEND_CYCLE = ("claude", "gemini", "local")
_BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}


def _request_user(request: Request) -> str | None:
    """Extract username from JWT cookie, or None."""
    try:
        token = request.cookies.get(COOKIE_NAME)
        return decode_token(token) if token else None
    except Exception:
        # Any decoding failure (including jwt.InvalidTokenError) is treated
        # as "no user".
        return None


def _local_model_info(request: Request) -> dict | None:
    """Return the best local model {label, model_name} for the session user, or None."""
    username = _request_user(request)
    if not username:
        return None
    try:
        cfg = model_registry.get_best_local_model(username, "chat")
        if cfg:
            return {"label": cfg.get("label", ""), "model_name": cfg.get("model_name", "")}
    except Exception:
        # Best effort: a registry failure only means no local model info.
        pass
    return None


def _chat_slot_models(username: str) -> list[dict]:
    """Return [{slot, label, type}] for each configured slot in the chat role, primary first."""
    registry = model_registry.get_registry(username)
    role_slots = registry.get("roles", {}).get("chat", {})
    result = []
    for slot_key in model_registry.PRIORITY_KEYS:
        model_id = role_slots.get(slot_key)
        if not model_id:
            continue
        resolved = model_registry._resolve_model(registry, model_id)
        if resolved:
            result.append({
                "slot": slot_key,
                "label": resolved.get("label") or resolved.get("model_name") or "",
                "type": resolved.get("type", ""),
            })
    return result


def _available_roles_for_toggle(username: str) -> list[dict]:
    """Return roles with a primary model assigned (excluding orchestrator) for the UI toggle.

    Returns [{role, label, model_label, type}] ordered by settings.defined_roles.
    """
    registry = model_registry.get_registry(username)
    roles_cfg = registry.get("roles", {})
    result = []
    for role_name in settings.get_defined_roles():
        if role_name == "orchestrator":
            continue
        primary_id = roles_cfg.get(role_name, {}).get("primary")
        if not primary_id:
            continue
        resolved = model_registry._resolve_model(registry, primary_id)
        if resolved:
            result.append({
                "role": role_name,
                "label": role_name.title(),
                "model_label": resolved.get("label") or resolved.get("model_name") or "",
                "type": resolved.get("type", ""),
            })
    return result


@router.get("/backend")
async def get_backend(request: Request) -> dict:
    """Report the models/roles available to the session user plus legacy backend fields.

    Without an authenticated user the per-user lists are empty and
    ``orchestrator_model`` is None.
    """
    username = _request_user(request)
    chat_models = _chat_slot_models(username) if username else []
    available_roles = _available_roles_for_toggle(username) if username else []
    p = settings.primary_backend
    orch_label = None
    if username:
        orch_cfg = model_registry.get_model_for_role(username, "orchestrator")
        if orch_cfg:
            orch_label = orch_cfg.get("label") or orch_cfg.get("model_name") or None
    return {
        "chat_models": chat_models,  # Phase 3: [{slot, label, type}] for chat-role slots
        "available_roles": available_roles,
        # kept for banner + backward compat
        "orchestrator_model": orch_label,
        # Legacy fields kept for backward compat
        "primary": p,
        "fallback": _BACKEND_FALLBACK.get(p, "claude"),
        "local_model": _local_model_info(request),
    }


@router.post("/backend")
async def set_backend(req: BackendRequest, request: Request) -> dict:
    """Set ``settings.primary_backend``; responds 400 for an unknown backend name."""
    if req.primary not in _BACKEND_CYCLE:
        raise HTTPException(
            status_code=400,
            detail="primary must be 'claude', 'gemini', or 'local'",
        )
    settings.primary_backend = req.primary
    return {
        "primary": req.primary,
        "fallback": _BACKEND_FALLBACK[req.primary],
        "local_model": _local_model_info(request),
    }


def _set_ctx(user: str, persona: str) -> None:
    """Validate and set persona context from query params.

    Raises HTTPException (404) on bad input.
    """
    try:
        u, p = validate_persona(user, persona)
        set_context(u, p)
    except ValueError as e:
        # Chain the cause so the original validation error stays in tracebacks.
        raise HTTPException(status_code=404, detail=str(e)) from e


@router.get("/history/{session_id}")
async def get_history(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return the stored name and message list for a session."""
    _set_ctx(user, persona)
    name = get_session_name(session_id)
    return {"session_id": session_id, "name": name, "messages": load_session(session_id)}


@router.get("/sessions")
async def list_sessions(
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """List all sessions for the given user/persona context."""
    _set_ctx(user, persona)
    return {"sessions": list_all()}


class SessionRename(BaseModel):
    """Request body for ``PATCH /sessions/{session_id}``."""

    name: str


@router.patch("/sessions/{session_id}")
async def rename_session_endpoint(
    session_id: str,
    req: SessionRename,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Rename a session to the whitespace-stripped name; 404 if it does not exist."""
    _set_ctx(user, persona)
    new_name = req.name.strip()
    found = rename_session(session_id, new_name)
    if not found:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"ok": True, "session_id": session_id, "name": new_name}


@router.post("/api/sessions/backfill-names")
async def backfill_session_names(
    request: Request,
    user: str = Query(""),
    persona: str = Query(""),
) -> dict:
    """Name every unnamed session using its first user message (truncated to 60 chars).

    Idempotent — only touches sessions that have no name set.
    user/persona default to the JWT session user + last-used persona cookie.
    """
    # Resolve user from JWT if not provided
    if not user:
        token = request.cookies.get(COOKIE_NAME)
        if not token:
            raise HTTPException(status_code=401, detail="Not authenticated")
        try:
            user = decode_token(token)
        except jwt.InvalidTokenError as exc:
            raise HTTPException(status_code=401, detail="Invalid session") from exc

    # Resolve persona from cookie if not provided
    if not persona:
        from persona import list_user_personas

        persona_cookie = request.cookies.get("cx_last_persona", "")
        available = list_user_personas(user)
        if persona_cookie in available:
            persona = persona_cookie
        else:
            persona = available[0] if available else ""
        if not persona:
            raise HTTPException(status_code=400, detail="No persona found for user")

    _set_ctx(user, persona)
    sessions = list_all()
    named = 0
    for s in sessions:
        if s.get("name"):
            continue
        messages = load_session(s["session_id"])
        first_user = next((m for m in messages if m.get("role") == "user"), None)
        if not first_user:
            continue
        text = (first_user.get("content") or "").strip()
        if not text:
            continue
        # Append an ellipsis only when the text was actually truncated.
        auto_name = text[:60].rstrip() + ("…" if len(text) > 60 else "")
        rename_session(s["session_id"], auto_name)
        named += 1
    return {"ok": True, "named": named, "total": len(sessions)}


@router.delete("/sessions/{session_id}")
async def delete_session_endpoint(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Delete a session; 404 if it does not exist."""
    _set_ctx(user, persona)
    found = delete_session(session_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"ok": True, "session_id": session_id}


@router.put("/history/{session_id}")
async def replace_history(
    session_id: str,
    req: HistoryUpdate,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Replace the full message list for a session (used by edit/delete UI)."""
    _set_ctx(user, persona)
    save_session(session_id, req.messages)
    return {"ok": True, "session_id": session_id}


@router.get("/events")
async def sse_events() -> StreamingResponse:
    """Server-sent events stream — pushes real-time Talk activity to the browser."""

    async def stream():
        q = event_bus.subscribe()
        try:
            while True:
                try:
                    event = await asyncio.wait_for(q.get(), timeout=20)
                    yield f"data: {json.dumps(event)}\n\n"
                except asyncio.TimeoutError:
                    # Nothing published for 20s: emit a keepalive instead.
                    yield 'data: {"type":"keepalive"}\n\n'
        finally:
            # Runs on close and on cancellation alike; cancellation itself is
            # left to propagate rather than being swallowed here.
            event_bus.unsubscribe(q)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@router.post("/note")
async def add_note(
    req: NoteRequest,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Inject a public note into session history so the LLM sees it next turn."""
    _set_ctx(user, persona)
    history = load_session(req.session_id)
    history.append({"role": "user", "content": f"[NOTE] {req.note}"})
    save_session(req.session_id, history)
    return {"ok": True, "session_id": req.session_id}