Files
Cortex-Inara/cortex/routers/chat.py
Scott Idem 96b3c796c5 feat: file attachment support in chat (images + text/code files)
Text files (.md, .py, .js, .json, etc.): read client-side and injected
into the message body as a fenced code block — works with all backends
with zero model capability requirements.

Images (PNG/JPG/WebP/GIF, max 5 MB): encoded as base64 data URL on the
client and sent as a separate attachment field. Backend formats them as
OpenAI multimodal content (text + image_url) for local_openai backends.
Claude CLI and Gemini CLI see the text message with a "📎 filename.png"
note; image data is never written to session history.

- index.html: 📎 button + hidden file input in mode-select row;
  attachment-row preview area with thumbnail (images) or filename chip
- app.js: _resolveAttachment(), file reader, clearAttachment();
  sendMessage/sendOrchestrate updated to allow no-text sends when a
  file is pending; attachment spread into chat payload for images
- chat.py: Attachment model; attachment field on ChatRequest;
  llm_attachment extracted in _stream_chat and passed to complete()
- llm_client.py: attachment param through complete()/_dispatch()/_local();
  _local() builds multimodal content array for vision calls
- style.css: #attach-btn, #attachment-row, #attachment-preview, thumb

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 21:46:50 -04:00

469 lines
16 KiB
Python

import asyncio
import json
import platform
import jwt
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from context_loader import load_context
from llm_client import complete
from session_logger import log_turn
from session_store import load as load_session, save as save_session, list_all, generate_session_id, delete as delete_session, rename as rename_session, get_name as get_session_name
from config import settings
from persona import set_context, validate as validate_persona
from auth_utils import COOKIE_NAME, decode_token
import model_registry
import event_bus
from model_registry import get_role_config
# Router holding the chat, session-history, backend-selection, note and SSE
# endpoints defined in this module; presumably mounted on the app elsewhere.
router = APIRouter()
def _backend_label(backend: str, username: str, role: str = "chat") -> str:
"""Human-readable label for the model that handled a request (legacy path)."""
if backend == "claude":
return "Claude"
if backend == "gemini":
return "Gemini"
if backend == "local":
cfg = model_registry.get_best_local_model(username, role)
if cfg:
return cfg.get("label") or cfg.get("model_name") or "Local"
return "Local"
return backend.title()
def _role_model_label(username: str, role: str, actual_backend: str) -> str:
    """Return the label of the model assigned to *role* for *username*.

    Prefers the registry entry's label, then its model_name. When neither is
    set, or no model is assigned, falls back to _backend_label().
    """
    assigned = model_registry.get_model_for_role(username, role)
    label = (assigned.get("label") or assigned.get("model_name")) if assigned else None
    return label or _backend_label(actual_backend, username, role)
class Attachment(BaseModel):
    """A file attached to a chat message.

    _stream_chat forwards an attachment to the LLM only when its mime_type
    starts with "image/"; text files are expected to arrive already embedded
    in ChatRequest.message.
    """
    # NOTE(review): no server-side size limit is enforced on `data` here;
    # any limit is presumably client-side only — confirm.
    filename: str
    mime_type: str  # e.g. "image/png"; checked with startswith("image/") in _stream_chat
    data: str # base64 data URL for images (e.g. "data:image/png;base64,...")
class ChatRequest(BaseModel):
    """Request body for POST /chat, consumed by _stream_chat.

    NOTE(review): `user` and `persona` come from the request body and are only
    checked by validate_persona(); nothing here ties them to the JWT cookie —
    confirm authentication is enforced elsewhere.
    """
    message: str  # full user text; client-side text attachments are already embedded here
    session_id: str | None = None  # None -> a new id from generate_session_id()
    tier: int | None = None  # None (or 0) -> settings.default_tier
    model: str | None = None # legacy backend override ("claude"|"gemini"|"local"); also drives fallback_used
    slot: str | None = None # Phase 3: explicit slot ("primary"|"backup_1"|"backup_2")
    chat_role: str = "chat" # active role: "chat"|"coder"|"research"|"distill" etc.
    include_long: bool = True  # the three include_* flags are passed straight to load_context()
    include_mid: bool = True
    include_short: bool = True
    off_record: bool = False  # skips log_turn() and uses mode "otr"; the turn is still saved to session history, tagged off_record
    user: str = "scott"
    persona: str = "inara"
    attachment: Attachment | None = None # image attachment (text files injected client-side)
class BackendRequest(BaseModel):
    """Request body for POST /backend (legacy primary-backend selector)."""
    primary: str # "claude", "gemini", or "local"
class NoteRequest(BaseModel):
    """Request body for POST /note."""
    session_id: str  # session whose history receives the note
    note: str  # stored as a user-role message prefixed with "[NOTE] "
class HistoryUpdate(BaseModel):
    """Request body for PUT /history/{session_id}: the replacement message list."""
    messages: list[dict]  # saved as-is by replace_history; no per-message schema here
def _sse_event(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _sse_error(message: str) -> str:
    """Return an SSE ``error`` frame carrying *message*."""
    return _sse_event({"type": "error", "message": message})


def _llm_attachment(attachment: "Attachment | None") -> dict | None:
    """Convert a request attachment into the plain dict passed to complete().

    Only image attachments (mime_type starting with "image/") are forwarded.
    Text files are embedded in the message body client-side, so any other
    attachment yields None.
    """
    if not attachment or not attachment.mime_type.startswith("image/"):
        return None
    return {
        "filename": attachment.filename,
        "mime_type": attachment.mime_type,
        "data": attachment.data,
    }


def _response_backend_label(user: str, req: ChatRequest, actual_backend: str) -> str:
    """Return the display label for the model that answered *req*.

    An explicit slot's configured label wins. Otherwise, or when that slot has
    no label, the role's model label is used.
    """
    if req.slot:
        slot_cfg = model_registry.get_model_for_slot(user, req.chat_role, req.slot)
        slot_label = (slot_cfg or {}).get("label")
        if slot_label:
            return slot_label
    return _role_model_label(user, req.chat_role, actual_backend)


async def _stream_chat(req: ChatRequest):
    """
    SSE generator for POST /chat.

    Sends a keepalive event every ~3 s while the LLM works, then the final
    response, so the browser connection stays alive however long the
    backend takes.

    Event types:
      data: {"type": "keepalive"}
      data: {"type": "response", "response": "...", "session_id": "...",
             "backend": "...", "backend_label": "...", "host": "...",
             "fallback_used": bool}
      data: {"type": "error", "message": "..."}
    """
    try:
        user, persona = validate_persona(req.user, req.persona)
        set_context(user, persona)
    except ValueError as e:
        yield _sse_error(str(e))
        return

    session_id = req.session_id or generate_session_id()
    try:
        tier = req.tier or settings.default_tier
        role_cfg = get_role_config(user, req.chat_role)
        system_prompt = load_context(
            tier,
            include_long=req.include_long,
            include_mid=req.include_mid,
            include_short=req.include_short,
            inject_datetime=role_cfg.get("inject_datetime", True),
            inject_mode=role_cfg.get("inject_mode", True),
            mode="otr" if req.off_record else "chat",
        )
        history = load_session(session_id)
    except Exception as e:
        # StreamingResponse has already begun the HTTP response by the time
        # this generator runs. An exception escaping here would only cut the
        # stream off, so report setup failures as an SSE error event instead.
        yield _sse_error(str(e))
        return

    # req.message already contains the full user text:
    #   - text files: the client embedded the content as a fenced code block
    #   - images: the client added a "📎 filename.png" note; the image data
    #     is in req.attachment
    # History stores text only — base64 image data is never written to disk.
    llm_attachment = _llm_attachment(req.attachment)
    history.append({"role": "user", "content": req.message, "off_record": req.off_record})

    task = asyncio.create_task(complete(
        system_prompt=system_prompt,
        messages=history,
        model=req.model,
        role=req.chat_role,
        slot=req.slot,
        attachment=llm_attachment,
    ))
    try:
        # Ping the browser every 3 s so it doesn't drop the connection.
        while not task.done():
            yield 'data: {"type":"keepalive"}\n\n'
            try:
                # shield() keeps the wait_for timeout from cancelling the LLM task.
                await asyncio.wait_for(asyncio.shield(task), timeout=3)
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The task itself failed; task.result() below re-raises it.
                break

        try:
            response_text, actual_backend = task.result()
            backend_label = _response_backend_label(user, req, actual_backend)
            host = platform.node()
            history.append({
                "role": "assistant",
                "content": response_text,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
                "off_record": req.off_record,
            })
            save_session(session_id, history)
            if not req.off_record:
                log_turn(session_id, req.message, response_text, backend_label, host)
            # fallback_used only makes sense for explicit backend selections.
            # In auto mode (req.model is None), just report what responded.
            fallback_used = bool(req.model and actual_backend != req.model)
            yield _sse_event({
                "type": "response",
                "response": response_text,
                "session_id": session_id,
                "backend": actual_backend,
                "backend_label": backend_label,
                "host": host,
                "fallback_used": fallback_used,
            })
        except Exception as e:
            yield _sse_error(str(e))
    finally:
        # Cancel the LLM task if the generator is torn down early (e.g. client
        # disconnect or server shutdown) so the backend call does not keep
        # running unattended.
        if not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Stream the reply to *req* as Server-Sent Events produced by _stream_chat."""
    # no-cache plus X-Accel-Buffering: no keep caches/reverse proxies from
    # buffering the event stream.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    events = _stream_chat(req)
    return StreamingResponse(events, media_type="text/event-stream", headers=sse_headers)
_BACKEND_CYCLE = ("claude", "gemini", "local")
_BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}
def _request_user(request: Request) -> str | None:
    """Return the username from the request's JWT cookie, or None.

    Best-effort: a missing cookie, or any exception raised while reading or
    decoding it, yields None rather than an error. Callers treat None as
    "no authenticated user".
    """
    try:
        token = request.cookies.get(COOKIE_NAME)
        return decode_token(token) if token else None
    except Exception:
        # Deliberately broad. jwt.InvalidTokenError is already an Exception
        # subclass, so listing it separately (as before) added nothing.
        return None
def _local_model_info(request: Request) -> dict | None:
    """Describe the best local chat model for the requesting user.

    Returns {"label": ..., "model_name": ...}, each defaulting to "" when the
    key is absent. Returns None when there is no authenticated user, no local
    model, or the registry lookup raises.
    """
    username = _request_user(request)
    if username:
        try:
            best = model_registry.get_best_local_model(username, "chat")
            if best:
                return {
                    "label": best.get("label", ""),
                    "model_name": best.get("model_name", ""),
                }
        except Exception:
            pass  # best-effort: registry problems just mean "no local model info"
    return None
def _chat_slot_models(username: str) -> list[dict]:
    """List the models configured for the chat role's slots.

    Returns one {"slot", "label", "type"} dict per slot, in
    model_registry.PRIORITY_KEYS order (presumably primary first). Slots that
    are unset or whose model id does not resolve are skipped.
    """
    registry = model_registry.get_registry(username)
    chat_slots = registry.get("roles", {}).get("chat", {})
    slot_models = []
    for slot in model_registry.PRIORITY_KEYS:
        model_id = chat_slots.get(slot)
        resolved = model_registry._resolve_model(registry, model_id) if model_id else None
        if not resolved:
            continue
        slot_models.append({
            "slot": slot,
            "label": resolved.get("label") or resolved.get("model_name") or "",
            "type": resolved.get("type", ""),
        })
    return slot_models
def _available_roles_for_toggle(username: str) -> list[dict]:
    """Return the roles the UI role toggle can offer.

    Produces [{role, label, model_label, type}] in settings.get_defined_roles()
    order. A role is included when it is not "orchestrator" and its primary
    model id is set and resolves in the user's registry.
    """
    registry = model_registry.get_registry(username)
    roles_cfg = registry.get("roles", {})
    toggles = []
    for role in settings.get_defined_roles():
        primary_id = None if role == "orchestrator" else roles_cfg.get(role, {}).get("primary")
        model = model_registry._resolve_model(registry, primary_id) if primary_id else None
        if model:
            toggles.append({
                "role": role,
                "label": role.title(),
                "model_label": model.get("label") or model.get("model_name") or "",
                "type": model.get("type", ""),
            })
    return toggles
@router.get("/backend")
async def get_backend(request: Request) -> dict:
    """Report backend/model configuration for the UI.

    chat_models, available_roles and orchestrator_model are filled in only for
    an authenticated user. The legacy primary/fallback/local_model fields are
    always returned.
    """
    username = _request_user(request)
    chat_models = []
    available_roles = []
    orch_label = None
    if username:
        chat_models = _chat_slot_models(username)
        available_roles = _available_roles_for_toggle(username)
        orch_cfg = model_registry.get_model_for_role(username, "orchestrator")
        if orch_cfg:
            orch_label = orch_cfg.get("label") or orch_cfg.get("model_name") or None
    primary = settings.primary_backend
    return {
        "chat_models": chat_models,  # Phase 3: [{slot, label, type}] for chat-role slots
        "available_roles": available_roles,  # kept for banner + backward compat
        "orchestrator_model": orch_label,
        # Legacy fields kept for backward compat
        "primary": primary,
        "fallback": _BACKEND_FALLBACK.get(primary, "claude"),
        "local_model": _local_model_info(request),
    }
@router.post("/backend")
async def set_backend(req: BackendRequest, request: Request) -> dict:
    """Switch the legacy primary backend.

    Raises HTTPException(400) unless req.primary is one of _BACKEND_CYCLE.
    NOTE(review): this mutates the shared `settings` object and performs no
    auth check here — confirm that is intended.
    """
    chosen = req.primary
    if chosen in _BACKEND_CYCLE:
        settings.primary_backend = chosen
        return {
            "primary": chosen,
            "fallback": _BACKEND_FALLBACK[chosen],
            "local_model": _local_model_info(request),
        }
    raise HTTPException(status_code=400, detail="primary must be 'claude', 'gemini', or 'local'")
def _set_ctx(user: str, persona: str) -> None:
    """Validate *user*/*persona* and install them as the active persona context.

    Any ValueError from validation or from setting the context is turned into
    HTTPException(404) carrying the original message.
    """
    try:
        valid_user, valid_persona = validate_persona(user, persona)
        set_context(valid_user, valid_persona)
    except ValueError as exc:
        detail = str(exc)
        raise HTTPException(status_code=404, detail=detail)
@router.get("/history/{session_id}")
async def get_history(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return a session's display name and its stored messages."""
    _set_ctx(user, persona)
    return {
        "session_id": session_id,
        "name": get_session_name(session_id),
        "messages": load_session(session_id),
    }
@router.get("/sessions")
async def list_sessions(
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Return the sessions reported by list_all() for the given user/persona context."""
    _set_ctx(user, persona)
    sessions = list_all()
    return {"sessions": sessions}
class SessionRename(BaseModel):
    """Request body for PATCH /sessions/{session_id}."""
    name: str  # surrounding whitespace is stripped by the endpoint before saving
@router.patch("/sessions/{session_id}")
async def rename_session_endpoint(
    session_id: str,
    req: SessionRename,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Rename a session to req.name with surrounding whitespace stripped.

    Raises HTTPException(404) when the session store reports no such session.
    """
    _set_ctx(user, persona)
    new_name = req.name.strip()
    if rename_session(session_id, new_name):
        return {"ok": True, "session_id": session_id, "name": new_name}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
_AUTO_NAME_MAX_CHARS = 60


def _auto_session_name(text: str, limit: int = _AUTO_NAME_MAX_CHARS) -> str:
    """Derive a session name from *text*: at most its first *limit* characters.

    A trailing "…" marks names that were cut short. The previous expression
    appended ``""`` in both branches, which made the truncation check a no-op.
    """
    if len(text) <= limit:
        return text.rstrip()
    return text[:limit].rstrip() + "…"


@router.post("/api/sessions/backfill-names")
async def backfill_session_names(
    request: Request,
    user: str = Query(""),
    persona: str = Query(""),
) -> dict:
    """Name every unnamed session after its first user message.

    Names are the message's first 60 characters, with "…" appended when it
    was longer. Idempotent — sessions that already have a name are skipped.
    user/persona default to the JWT session user and the "cx_last_persona"
    cookie, falling back to the user's first persona.

    Raises HTTPException 401 when user is omitted and the JWT cookie is
    missing or invalid, 400 when no persona can be resolved, and whatever
    _set_ctx raises for an invalid user/persona pair.
    """
    # Resolve user from the JWT cookie if not provided.
    if not user:
        token = request.cookies.get(COOKIE_NAME)
        if not token:
            raise HTTPException(status_code=401, detail="Not authenticated")
        try:
            user = decode_token(token)
        except jwt.InvalidTokenError:
            raise HTTPException(status_code=401, detail="Invalid session")
        if not user:
            # Defensive: don't continue with an empty subject if decode_token
            # ever returns a falsy value instead of raising.
            raise HTTPException(status_code=401, detail="Invalid session")

    # Resolve persona from the last-used cookie if not provided.
    if not persona:
        from persona import list_user_personas
        persona_cookie = request.cookies.get("cx_last_persona", "")
        available = list_user_personas(user)
        persona = persona_cookie if persona_cookie in available else (available[0] if available else "")
    if not persona:
        raise HTTPException(status_code=400, detail="No persona found for user")

    _set_ctx(user, persona)
    sessions = list_all()
    named = 0
    for s in sessions:
        if s.get("name"):
            continue
        messages = load_session(s["session_id"])
        first_user = next((m for m in messages if m.get("role") == "user"), None)
        text = ((first_user or {}).get("content") or "").strip()
        if not text:
            continue
        rename_session(s["session_id"], _auto_session_name(text))
        named += 1
    return {"ok": True, "named": named, "total": len(sessions)}
@router.delete("/sessions/{session_id}")
async def delete_session_endpoint(
    session_id: str,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Delete a session; raise HTTPException(404) if the store reports it missing."""
    _set_ctx(user, persona)
    deleted = delete_session(session_id)
    if deleted:
        return {"ok": True, "session_id": session_id}
    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
@router.put("/history/{session_id}")
async def replace_history(
    session_id: str,
    req: HistoryUpdate,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Overwrite a session's stored messages with req.messages.

    Used by the edit/delete UI; the list is saved exactly as received.
    """
    _set_ctx(user, persona)
    new_messages = req.messages
    save_session(session_id, new_messages)
    return {"ok": True, "session_id": session_id}
@router.get("/events")
async def sse_events() -> StreamingResponse:
    """Server-sent events stream — pushes real-time Talk activity to the browser.

    Each event taken from the event_bus subscription is JSON-encoded into a
    ``data:`` frame. A keepalive frame is sent after 20 s without events.
    """
    async def stream():
        q = event_bus.subscribe()
        try:
            while True:
                try:
                    event = await asyncio.wait_for(q.get(), timeout=20)
                except asyncio.TimeoutError:
                    yield 'data: {"type":"keepalive"}\n\n'
                    continue
                yield f"data: {json.dumps(event)}\n\n"
        finally:
            # Always unsubscribe. GeneratorExit (client disconnect) and
            # CancelledError are deliberately NOT swallowed: suppressing
            # CancelledError would silently cancel the cancellation request
            # for the task driving this stream.
            event_bus.unsubscribe(q)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@router.post("/note")
async def add_note(
    req: NoteRequest,
    user: str = Query("scott"),
    persona: str = Query("inara"),
) -> dict:
    """Append a public note to a session so the LLM sees it on the next turn.

    The note is stored as a user-role message whose content is "[NOTE] " plus
    the note text.
    """
    _set_ctx(user, persona)
    sid = req.session_id
    messages = load_session(sid)
    messages.append({"role": "user", "content": f"[NOTE] {req.note}"})
    save_session(sid, messages)
    return {"ok": True, "session_id": sid}