/history/{session_id} now returns a 'name' field alongside messages.
resumeSession() uses data.name first, then the sessionNames map, then
raw ID as fallback — so named sessions display correctly even on page
load before the sessions panel has been opened.
The 'Resumed session X' message now also shows the friendly name.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
360 lines
12 KiB
Python
360 lines
12 KiB
Python
import asyncio
|
|
import json
|
|
import platform
|
|
import jwt
|
|
from fastapi import APIRouter, HTTPException, Query, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
from context_loader import load_context
|
|
from llm_client import complete
|
|
from session_logger import log_turn
|
|
from session_store import load as load_session, save as save_session, list_all, generate_session_id, delete as delete_session, rename as rename_session, get_name as get_session_name
|
|
from config import settings
|
|
from persona import set_context, validate as validate_persona
|
|
from auth_utils import COOKIE_NAME, decode_token
|
|
import model_registry
|
|
import event_bus
|
|
|
|
|
|
# Router that the endpoint functions defined in this module register themselves on.
router = APIRouter()
|
|
|
|
|
|
def _backend_label(backend: str, username: str, role: str = "chat") -> str:
|
|
"""Human-readable label for the model that handled a request (legacy path)."""
|
|
if backend == "claude":
|
|
return "Claude"
|
|
if backend == "gemini":
|
|
return "Gemini"
|
|
if backend == "local":
|
|
cfg = model_registry.get_best_local_model(username, role)
|
|
if cfg:
|
|
return cfg.get("label") or cfg.get("model_name") or "Local"
|
|
return "Local"
|
|
return backend.title()
|
|
|
|
|
|
def _role_model_label(username: str, role: str, actual_backend: str) -> str:
    """Return the display label of the model configured for *role*.

    Uses the role's configured model ``label`` (else ``model_name``) when the
    registry has one; otherwise falls back to the generic label for
    *actual_backend*.
    """
    role_cfg = model_registry.get_model_for_role(username, role)
    configured = (role_cfg.get("label") or role_cfg.get("model_name")) if role_cfg else None
    if configured:
        return configured
    return _backend_label(actual_backend, username, role)
|
|
|
|
|
|
class ChatRequest(BaseModel):
    # Request body for POST /chat.

    # Text of the user's turn; appended to the session history as a "user" message.
    message: str
    # Existing session to continue; a new id is generated when omitted.
    session_id: str | None = None
    # Context tier handed to load_context; settings.default_tier is used when unset.
    tier: int | None = None
    # Legacy backend override ("claude"|"gemini"|"local"); None means auto mode.
    model: str | None = None
    # Active role: "chat"|"coder"|"research"|"distill" etc.
    chat_role: str = "chat"
    # Forwarded to load_context as include_long / include_mid / include_short.
    include_long: bool = True
    include_mid: bool = True
    include_short: bool = True
    # Skip the session log for this turn (the saved session history still gets it).
    off_record: bool = False
    # User / persona pair validated before the request is processed.
    user: str = "scott"
    persona: str = "inara"
|
|
|
|
|
|
class BackendRequest(BaseModel):
    # Request body for POST /backend.

    # Backend to make primary: "claude", "gemini", or "local".
    primary: str
|
|
|
|
|
|
class NoteRequest(BaseModel):
    # Request body for POST /note.

    # Session whose history receives the note.
    session_id: str
    # Note text; stored as a user message prefixed with "[NOTE] ".
    note: str
|
|
|
|
|
|
class HistoryUpdate(BaseModel):
    # Request body for PUT /history/{session_id}.

    # Complete replacement message list for the session.
    messages: list[dict]
|
|
|
|
|
|
def _sse_error(exc: BaseException) -> str:
    """Format *exc* as an SSE ``error`` event whose message is never blank."""
    # Some exceptions (e.g. a bare asyncio.TimeoutError) stringify to "";
    # fall back to the class name so the client always has something to show.
    message = str(exc) or type(exc).__name__
    return f"data: {json.dumps({'type': 'error', 'message': message})}\n\n"


async def _stream_chat(req: ChatRequest):
    """
    SSE generator: sends keepalive events every 3s while the LLM works,
    then sends the final response. Keeps the browser connection alive
    regardless of how long the backend takes.

    The user turn is appended to the loaded session history before the LLM
    call; the assistant turn is appended and the session saved only after a
    successful completion. The turn is written to the session log unless
    ``req.off_record`` is set.

    Event types:
      data: {"type": "keepalive"}
      data: {"type": "response", "response": "...", "session_id": "...",
             "backend": "...", "backend_label": "...", "host": "...",
             "fallback_used": bool}
      data: {"type": "error", "message": "..."}
    """
    try:
        user, persona = validate_persona(req.user, req.persona)
        set_context(user, persona)
    except ValueError as e:
        yield _sse_error(e)
        return

    session_id = req.session_id or generate_session_id()
    tier = req.tier or settings.default_tier

    system_prompt = load_context(
        tier,
        include_long=req.include_long,
        include_mid=req.include_mid,
        include_short=req.include_short,
    )
    history = load_session(session_id)
    history.append({"role": "user", "content": req.message})

    task = asyncio.create_task(complete(
        system_prompt=system_prompt,
        messages=history,
        model=req.model,
        role=req.chat_role,
    ))

    try:
        # Ping the browser every 3s so it doesn't drop the connection.
        while not task.done():
            yield 'data: {"type":"keepalive"}\n\n'
            try:
                # shield() keeps the wait_for timeout from cancelling the LLM task.
                await asyncio.wait_for(asyncio.shield(task), timeout=3)
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The task failed; task.result() below re-raises and reports it.
                break

        try:
            response_text, actual_backend = task.result()
            backend_label = _role_model_label(user, req.chat_role, actual_backend)
            host = platform.node()
            history.append({
                "role": "assistant",
                "content": response_text,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
            })
            save_session(session_id, history)
            if not req.off_record:
                log_turn(session_id, req.message, response_text, backend_label, host)

            # fallback_used only makes sense for explicit backend selections.
            # In auto mode (req.model is None), just report what responded.
            fallback_used = bool(req.model and actual_backend != req.model)
            payload = {
                "type": "response",
                "response": response_text,
                "session_id": session_id,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
                "fallback_used": fallback_used,
            }
            yield f"data: {json.dumps(payload)}\n\n"

        except Exception as e:
            yield _sse_error(e)

    finally:
        # Ensure the LLM task is cancelled if the generator is torn down
        # (e.g. client disconnect or server shutdown). This propagates
        # CancelledError into _gemini() which kills the process group.
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
|
|
|
|
|
|
@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Stream one chat turn to the client as server-sent events."""
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    event_stream = _stream_chat(req)
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
# Backend names accepted as a primary backend by POST /backend.
_BACKEND_CYCLE = ("claude", "gemini", "local")
|
|
# Fallback backend reported to the client for each primary backend.
_BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}
|
|
|
|
|
|
def _request_user(request: Request) -> str | None:
    """Return the username decoded from the auth cookie, or None.

    None is returned when the cookie is missing or empty, and also when
    reading or decoding it raises any exception.
    """
    try:
        raw_token = request.cookies.get(COOKIE_NAME)
        if not raw_token:
            return None
        return decode_token(raw_token)
    except Exception:
        # Any failure while reading or decoding the cookie means "no user".
        return None
|
|
|
|
|
|
def _local_model_info(request: Request) -> dict | None:
    """Return ``{label, model_name}`` of the request user's best local chat model.

    Returns None when the request carries no recognised user, when no local
    model is available, or when the registry lookup raises.
    """
    username = _request_user(request)
    if not username:
        return None
    try:
        best = model_registry.get_best_local_model(username, "chat")
        info = (
            {"label": best.get("label", ""), "model_name": best.get("model_name", "")}
            if best
            else None
        )
    except Exception:
        # Best-effort lookup: a registry failure is reported as "no local model".
        return None
    return info
|
|
|
|
|
|
def _available_roles_for_toggle(username: str) -> list[dict]:
    """List the roles the UI toggle can offer for *username*.

    A role is included when it is not "orchestrator", has a ``primary`` model
    id in the user's registry, and that id resolves to a model entry.

    Returns [{role, label, model_label, type}] in the order produced by
    settings.get_defined_roles().
    """
    registry = model_registry.get_registry(username)
    role_settings = registry.get("roles", {})
    toggles: list[dict] = []
    for role in settings.get_defined_roles():
        if role != "orchestrator":
            model_id = role_settings.get(role, {}).get("primary")
            model = model_registry._resolve_model(registry, model_id) if model_id else None
            if model:
                entry = {
                    "role": role,
                    "label": role.title(),
                    "model_label": model.get("label") or model.get("model_name") or "",
                    "type": model.get("type", ""),
                }
                toggles.append(entry)
    return toggles
|
|
|
|
|
|
@router.get("/backend")
async def get_backend(request: Request) -> dict:
    """Report the roles available to the caller plus the legacy backend fields."""
    caller = _request_user(request)
    if caller:
        roles = _available_roles_for_toggle(caller)
    else:
        roles = []
    primary = settings.primary_backend
    response = {"available_roles": roles}
    # Legacy fields kept for backward compat
    response["primary"] = primary
    response["fallback"] = _BACKEND_FALLBACK.get(primary, "claude")
    response["local_model"] = _local_model_info(request)
    return response
|
|
|
|
|
|
@router.post("/backend")
async def set_backend(req: BackendRequest, request: Request) -> dict:
    """Set the primary backend in settings and return the resulting legacy fields.

    Raises:
        HTTPException: 400 when ``req.primary`` is not a known backend name.
    """
    chosen = req.primary
    if chosen not in _BACKEND_CYCLE:
        raise HTTPException(status_code=400, detail="primary must be 'claude', 'gemini', or 'local'")
    settings.primary_backend = chosen
    fallback = _BACKEND_FALLBACK[chosen]
    local_model = _local_model_info(request)
    return {"primary": chosen, "fallback": fallback, "local_model": local_model}
|
|
|
|
|
|
def _set_ctx(user: str, persona: str) -> None:
    """Validate *user*/*persona* from query params and set the persona context.

    Raises:
        HTTPException: 404 carrying the validation message when the pair is
            rejected with a ValueError.
    """
    try:
        u, p = validate_persona(user, persona)
        set_context(u, p)
    except ValueError as e:
        # Chain explicitly so the original validation error stays attached as the cause.
        raise HTTPException(status_code=404, detail=str(e)) from e
|
|
|
|
|
|
@router.get("/history/{session_id}")
async def get_history(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return a session's stored name and messages for the given user/persona."""
    _set_ctx(user, persona)
    session_name = get_session_name(session_id)
    messages = load_session(session_id)
    return {"session_id": session_id, "name": session_name, "messages": messages}
|
|
|
|
|
|
@router.get("/sessions")
async def list_sessions(
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return every stored session for the given user/persona context."""
    _set_ctx(user, persona)
    sessions = list_all()
    return {"sessions": sessions}
|
|
|
|
|
|
class SessionRename(BaseModel):
    # Request body for PATCH /sessions/{session_id}.

    # New display name; surrounding whitespace is stripped by the endpoint.
    name: str
|
|
|
|
|
|
@router.patch("/sessions/{session_id}")
async def rename_session_endpoint(
    session_id: str,
    req: SessionRename,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Rename a session to the whitespace-stripped name from the request body.

    Raises:
        HTTPException: 404 when the session store reports no such session.
    """
    _set_ctx(user, persona)
    new_name = req.name.strip()
    if rename_session(session_id, new_name):
        return {"ok": True, "session_id": session_id, "name": new_name}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def delete_session_endpoint(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Delete a session.

    Raises:
        HTTPException: 404 when the session store reports no such session.
    """
    _set_ctx(user, persona)
    if delete_session(session_id):
        return {"ok": True, "session_id": session_id}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
|
|
@router.put("/history/{session_id}")
async def replace_history(
    session_id: str,
    req: HistoryUpdate,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Overwrite a session's entire message list (used by edit/delete UI)."""
    _set_ctx(user, persona)
    new_messages = req.messages
    save_session(session_id, new_messages)
    return {"ok": True, "session_id": session_id}
|
|
|
|
|
|
@router.get("/events")
async def sse_events() -> StreamingResponse:
    """Server-sent events stream — pushes real-time Talk activity to the browser.

    Each event taken from the event-bus queue is sent as one JSON ``data:``
    frame. When no event arrives within 20 seconds a keepalive frame is sent
    instead.
    """
    async def stream():
        q = event_bus.subscribe()
        try:
            while True:
                try:
                    event = await asyncio.wait_for(q.get(), timeout=20)
                except asyncio.TimeoutError:
                    yield 'data: {"type":"keepalive"}\n\n'
                    continue
                yield f"data: {json.dumps(event)}\n\n"
        finally:
            # Runs on normal close, generator teardown and task cancellation.
            # Cancellation is deliberately not swallowed: it propagates to the
            # caller once the queue has been unsubscribed.
            event_bus.unsubscribe(q)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
@router.post("/note")
async def add_note(
    req: NoteRequest,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Append a "[NOTE] ..." user message to the session so the LLM sees it next turn."""
    _set_ctx(user, persona)
    target_session = req.session_id
    note_message = {"role": "user", "content": f"[NOTE] {req.note}"}
    messages = load_session(target_session)
    messages.append(note_message)
    save_session(target_session, messages)
    return {"ok": True, "session_id": target_session}
|