Files
Cortex-Inara/cortex/routers/chat.py
Scott Idem 8baab874f1 feat: replace backend/slot toggle with role selector
The backend toggle now cycles through configured roles (chat, coder,
research, distill, etc.) instead of backup model slots within the chat
role. Each role uses its own primary→backup chain from the registry.

- ChatRequest.slot replaced by chat_role (default "chat")
- GET /backend returns available_roles instead of chat_models
- _available_roles_for_toggle() builds list from defined_roles, excluding
  orchestrator (which has its own Agent mode)
- Model label on responses now reflects the actual role's assigned model
- Toggle is inert when only one role is configured (avoids useless cycling)
- Add "Clear browser cache" button to Account Settings (Connected Accounts)
- Add _role_model_label() helper for cleaner response tag labeling

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-28 19:23:18 -04:00

359 lines
12 KiB
Python

import asyncio
import json
import platform
import jwt
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from context_loader import load_context
from llm_client import complete
from session_logger import log_turn
from session_store import load as load_session, save as save_session, list_all, generate_session_id, delete as delete_session, rename as rename_session
from config import settings
from persona import set_context, validate as validate_persona
from auth_utils import COOKIE_NAME, decode_token
import model_registry
import event_bus
# Router that every endpoint in this module registers itself on via @router.<method>(...).
router = APIRouter()
def _backend_label(backend: str, username: str, role: str = "chat") -> str:
"""Human-readable label for the model that handled a request (legacy path)."""
if backend == "claude":
return "Claude"
if backend == "gemini":
return "Gemini"
if backend == "local":
cfg = model_registry.get_best_local_model(username, role)
if cfg:
return cfg.get("label") or cfg.get("model_name") or "Local"
return "Local"
return backend.title()
def _role_model_label(username: str, role: str, actual_backend: str) -> str:
    """Display name of the model assigned to *role* for *username*.

    Uses the registry entry's ``label``, then its ``model_name``. When the role
    has no entry, or the entry has neither field set, the generic label for
    *actual_backend* is used instead.
    """
    role_cfg = model_registry.get_model_for_role(username, role)
    registry_label = None
    if role_cfg:
        registry_label = role_cfg.get("label") or role_cfg.get("model_name")
    # The generic label is only computed when the registry gave us nothing.
    return registry_label or _backend_label(actual_backend, username, role)
class ChatRequest(BaseModel):
    """Body of POST /chat: one user message plus routing and context options."""

    message: str
    # Session to continue; a fresh id is generated when this is omitted.
    session_id: str | None = None
    # Context tier passed to load_context; settings.default_tier is used when unset.
    tier: int | None = None
    # Legacy backend override ("claude"|"gemini"|"local").
    model: str | None = None
    # Active role: "chat"|"coder"|"research"|"distill" etc.
    chat_role: str = "chat"
    # Forwarded to load_context as its include_long / include_mid / include_short flags.
    include_long: bool = True
    include_mid: bool = True
    include_short: bool = True
    # When true the turn is not written to the session log (session history is still saved).
    off_record: bool = False
    # User / persona pair, checked by validate_persona before the turn runs.
    user: str = "scott"
    persona: str = "inara"
class BackendRequest(BaseModel):
    """Body of POST /backend: the backend to make primary."""

    # One of "claude", "gemini", or "local".
    primary: str
class NoteRequest(BaseModel):
    """Body of POST /note: a note to append to a session's history."""

    # Session whose history receives the note.
    session_id: str
    # Note text; stored with a "[NOTE] " prefix as a user message.
    note: str
class HistoryUpdate(BaseModel):
    """Body of PUT /history/{session_id}: the complete replacement message list."""

    # Saved as-is in place of the session's existing messages.
    messages: list[dict]
def _sse_event(payload: dict) -> str:
    """Format *payload* as a single server-sent-event ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


async def _stream_chat(req: ChatRequest):
    """
    SSE generator: sends keepalive events every 3s while the LLM works,
    then sends the final response. Keeps the browser connection alive
    regardless of how long the backend takes.

    Every failure (invalid user/persona, a problem loading context or
    session history, or an LLM error) is reported to the client as an
    ``error`` event instead of silently breaking the stream.

    Event types:
        data: {"type": "keepalive"}
        data: {"type": "response", "response": "...", "session_id": "...",
               "backend": "...", "backend_label": "...", "host": "...",
               "fallback_used": bool}
        data: {"type": "error", "message": "..."}
    """
    try:
        user, persona = validate_persona(req.user, req.persona)
        set_context(user, persona)
    except ValueError as e:
        yield _sse_event({"type": "error", "message": str(e)})
        return

    # By the time this generator runs the HTTP response has already started,
    # so an exception escaping from setup would just truncate the stream.
    # Report setup failures the same way LLM failures are reported below.
    try:
        session_id = req.session_id or generate_session_id()
        tier = req.tier or settings.default_tier
        system_prompt = load_context(
            tier,
            include_long=req.include_long,
            include_mid=req.include_mid,
            include_short=req.include_short,
        )
        history = load_session(session_id)
        history.append({"role": "user", "content": req.message})
    except Exception as e:
        yield _sse_event({"type": "error", "message": str(e)})
        return

    task = asyncio.create_task(complete(
        system_prompt=system_prompt,
        messages=history,
        model=req.model,
        role=req.chat_role,
    ))

    try:
        # Ping the browser every 3s so it doesn't drop the connection
        while not task.done():
            yield 'data: {"type":"keepalive"}\n\n'
            try:
                # shield() keeps a wait_for timeout from cancelling the LLM task itself.
                await asyncio.wait_for(asyncio.shield(task), timeout=3)
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The task failed; task.result() below re-raises and reports it.
                break

        try:
            response_text, actual_backend = task.result()
            backend_label = _role_model_label(user, req.chat_role, actual_backend)
            host = platform.node()
            history.append({
                "role": "assistant",
                "content": response_text,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
            })
            save_session(session_id, history)
            if not req.off_record:
                log_turn(session_id, req.message, response_text, backend_label, host)

            # fallback_used only makes sense for explicit backend selections.
            # In auto mode (req.model is None), just report what responded.
            fallback_used = bool(req.model and actual_backend != req.model)
            payload = {
                "type": "response",
                "response": response_text,
                "session_id": session_id,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
                "fallback_used": fallback_used,
            }
            yield _sse_event(payload)
        except Exception as e:
            yield _sse_event({"type": "error", "message": str(e)})
    finally:
        # Ensure the LLM task is cancelled if the generator is torn down
        # (e.g. client disconnect or server shutdown). Per the original
        # author's note, the Gemini backend relies on this CancelledError
        # to kill its process group.
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Run one chat turn and stream keepalives plus the result as server-sent events."""
    # no-cache / X-Accel-Buffering: no ask caches and buffering proxies to pass events straight through.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    event_stream = _stream_chat(req)
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers=sse_headers,
    )
_BACKEND_CYCLE = ("claude", "gemini", "local")
_BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}
def _request_user(request: Request) -> str | None:
    """Return the username decoded from the request's auth cookie, or None.

    None covers both a missing/empty cookie and any failure while reading or
    decoding it.
    """
    try:
        token = request.cookies.get(COOKIE_NAME)
        if not token:
            return None
        return decode_token(token)
    except Exception:
        # jwt.InvalidTokenError is an Exception subclass, so this one clause
        # covers bad tokens as well as any other decoding problem.
        return None
def _local_model_info(request: Request) -> dict | None:
    """Describe the session user's best local "chat" model as {label, model_name}.

    Returns None when the request has no identifiable user, when the registry
    has no local model for them, or when the lookup fails for any reason.
    """
    username = _request_user(request)
    if not username:
        return None

    try:
        best = model_registry.get_best_local_model(username, "chat")
        if not best:
            return None
        return {
            "label": best.get("label", ""),
            "model_name": best.get("model_name", ""),
        }
    except Exception:
        # Best effort: this only feeds an informational field of the response.
        return None
def _available_roles_for_toggle(username: str) -> list[dict]:
    """Return roles with a primary model assigned (excluding orchestrator) for the UI toggle.

    Walks settings.get_defined_roles() in order and keeps every role whose
    ``primary`` model id resolves through the registry. The orchestrator role
    is always skipped.

    Roles whose registry entry is missing, null/empty, or lacks a ``primary``
    are treated as unconfigured and skipped rather than raising.

    Returns:
        [{role, label, model_label, type}] ordered by the defined roles.
    """
    registry = model_registry.get_registry(username)
    # dict.get's default only covers a *missing* key; "or {}" also covers a
    # key that is present but stored as None.
    roles_cfg = registry.get("roles") or {}

    result = []
    for role_name in settings.get_defined_roles():
        if role_name == "orchestrator":
            continue

        role_entry = roles_cfg.get(role_name) or {}
        primary_id = role_entry.get("primary")
        if not primary_id:
            continue

        # NOTE(review): this uses the registry module's private _resolve_model helper.
        resolved = model_registry._resolve_model(registry, primary_id)
        if not resolved:
            continue

        result.append({
            "role": role_name,
            "label": role_name.title(),
            "model_label": resolved.get("label") or resolved.get("model_name") or "",
            "type": resolved.get("type", ""),
        })
    return result
@router.get("/backend")
async def get_backend(request: Request) -> dict:
    """Report the roles available to the role toggle, plus the legacy backend fields.

    ``available_roles`` is empty when the request carries no identifiable user.
    """
    roles: list[dict] = []
    username = _request_user(request)
    if username:
        roles = _available_roles_for_toggle(username)

    primary = settings.primary_backend
    response = {
        "available_roles": roles,
        # Legacy fields kept for backward compat
        "primary": primary,
        "fallback": _BACKEND_FALLBACK.get(primary, "claude"),
        "local_model": _local_model_info(request),
    }
    return response
@router.post("/backend")
async def set_backend(req: BackendRequest, request: Request) -> dict:
    """Set the primary backend on the shared ``settings`` object (legacy toggle).

    Raises:
        HTTPException: 400 when ``req.primary`` is not one of the known backends.
    """
    chosen = req.primary
    if chosen not in _BACKEND_CYCLE:
        raise HTTPException(
            status_code=400,
            detail="primary must be 'claude', 'gemini', or 'local'",
        )

    settings.primary_backend = chosen
    return {
        "primary": chosen,
        "fallback": _BACKEND_FALLBACK[chosen],
        "local_model": _local_model_info(request),
    }
def _set_ctx(user: str, persona: str) -> None:
    """Validate the user/persona query params and make them the active persona context.

    Raises:
        HTTPException: 404 carrying the validator's message when validation
            raises ValueError.
    """
    try:
        valid_user, valid_persona = validate_persona(user, persona)
        set_context(valid_user, valid_persona)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
@router.get("/history/{session_id}")
async def get_history(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return the stored messages of one session for the given user/persona."""
    _set_ctx(user, persona)
    messages = load_session(session_id)
    return {"session_id": session_id, "messages": messages}
@router.get("/sessions")
async def list_sessions(
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """List the sessions the session store reports for the given user/persona."""
    _set_ctx(user, persona)
    sessions = list_all()
    return {"sessions": sessions}
class SessionRename(BaseModel):
    """Body of PATCH /sessions/{session_id}: the session's new name."""

    # Surrounding whitespace is stripped by the endpoint before the name is stored.
    name: str
@router.patch("/sessions/{session_id}")
async def rename_session_endpoint(
    session_id: str,
    req: SessionRename,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Rename a session to the whitespace-stripped ``req.name``.

    Raises:
        HTTPException: 404 when the session store reports no such session.
    """
    _set_ctx(user, persona)
    new_name = req.name.strip()
    if not rename_session(session_id, new_name):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"ok": True, "session_id": session_id, "name": new_name}
@router.delete("/sessions/{session_id}")
async def delete_session_endpoint(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Delete a session.

    Raises:
        HTTPException: 404 when the session store reports no such session.
    """
    _set_ctx(user, persona)
    if not delete_session(session_id):
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"ok": True, "session_id": session_id}
@router.put("/history/{session_id}")
async def replace_history(
    session_id: str,
    req: HistoryUpdate,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Overwrite a session's entire message list (used by the edit/delete UI)."""
    _set_ctx(user, persona)
    replacement = req.messages
    save_session(session_id, replacement)
    return {"ok": True, "session_id": session_id}
@router.get("/events")
async def sse_events() -> StreamingResponse:
    """Stream event-bus events to the browser as server-sent events.

    Each subscriber gets its own queue from the event bus. If no event arrives
    within 20 seconds a keepalive frame is sent instead. The queue is
    unsubscribed when the stream ends for any reason.
    """

    async def event_stream():
        queue = event_bus.subscribe()
        try:
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=20)
                except asyncio.TimeoutError:
                    yield 'data: {"type":"keepalive"}\n\n'
                else:
                    yield f"data: {json.dumps(event)}\n\n"
        except (GeneratorExit, asyncio.CancelledError):
            # The stream is being closed or cancelled; end quietly and let
            # the finally clause release the subscription.
            pass
        finally:
            event_bus.unsubscribe(queue)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.post("/note")
async def add_note(
    req: NoteRequest,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Append a "[NOTE] ..." user message to a session so the LLM sees it next turn."""
    _set_ctx(user, persona)
    target_session = req.session_id
    note_message = {"role": "user", "content": f"[NOTE] {req.note}"}

    history = load_session(target_session)
    history.append(note_message)
    save_session(target_session, history)
    return {"ok": True, "session_id": target_session}