Files
Cortex-Inara/cortex/routers/chat.py
Scott Idem ce3c1f5f7f Add tiered memory system with manual distillation
- config.py: memory_budget_long/mid/short settings (overridable in .env)
- memory_distiller.py: distill_short (no LLM), distill_mid, distill_long (LLM)
- routers/distill.py: POST /distill/{short,mid,long,all} endpoints
- context_loader.py: rewrote to load long→mid→short order with include_* toggles
- routers/chat.py: ChatRequest gains include_long/mid/short fields
- routers/files.py: MEMORY_LONG/MID/SHORT.md added to ALLOWED set
- main.py: register distill router
- static/index.html: context bar — tier selector, L/M/S memory toggles,
  distill buttons with status feedback; send includes tier + memory flags
- inara/MEMORY_LONG.md: migrated from MEMORY.md + Cortex/Talk bot notes
- inara/MEMORY_MID.md, MEMORY_SHORT.md: stubs ready for distillation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 21:22:32 -04:00

184 lines
5.8 KiB
Python

import asyncio
import json
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from context_loader import load_context
from llm_client import complete
from session_logger import log_turn
from session_store import load as load_session, save as save_session, list_all, generate_session_id
from config import settings
import event_bus
# Router that the chat, backend, history, sessions, events and note endpoints
# in this module attach themselves to via the @router.<method>(...) decorators.
router = APIRouter()
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # Text of the user's turn; appended to the session history as a "user" message.
    message: str
    # Session to continue; when missing/falsy a new id is generated for the turn.
    session_id: str | None = None
    # Context tier handed to load_context; settings.default_tier is used when falsy.
    tier: int | None = None
    model: str | None = None  # "claude" or "gemini" to override; None = use primary_backend
    # Memory toggles forwarded to load_context as include_long / include_mid /
    # include_short respectively.
    include_long: bool = True
    include_mid: bool = True
    include_short: bool = True
class BackendRequest(BaseModel):
    """Request body for POST /backend."""

    # The handler rejects any value other than the two below with HTTP 400.
    primary: str  # "claude" or "gemini"
class NoteRequest(BaseModel):
    """Request body for POST /note."""

    # Session whose stored history receives the note.
    session_id: str
    # Note text; the handler stores it as a user message prefixed with "[NOTE] ".
    note: str
class HistoryUpdate(BaseModel):
    """Request body for PUT /history/{session_id}."""

    # Complete replacement message list. Messages written elsewhere in this
    # module look like {"role": ..., "content": ...}, but no per-message schema
    # is enforced by this model.
    messages: list[dict]
async def _stream_chat(req: ChatRequest):
    """
    Async generator producing the SSE body for a single chat turn.

    The LLM call runs as a background task. While it is pending, a keepalive
    event is emitted about every 3 seconds so the browser does not drop the
    connection, however long the backend takes. Once the task settles, one
    response or error event is sent.

    Event types:
        data: {"type": "keepalive"}
        data: {"type": "response", "response": "...", "session_id": "...",
               "backend": "...", "fallback_used": bool}
        data: {"type": "error", "message": "..."}
    """
    sid = req.session_id or generate_session_id()
    # NOTE(review): `or` also swaps a tier of 0 for the default — confirm that
    # 0 is never a valid tier.
    chosen_tier = req.tier or settings.default_tier
    context = load_context(
        chosen_tier,
        include_long=req.include_long,
        include_mid=req.include_mid,
        include_short=req.include_short,
    )

    turns = load_session(sid)
    turns.append({"role": "user", "content": req.message})

    llm_task = asyncio.create_task(
        complete(system_prompt=context, messages=turns, model=req.model)
    )

    try:
        # Heartbeat loop: send a keepalive, then wait up to 3s for the task.
        # shield() stops wait_for's timeout from cancelling the LLM call itself.
        while not llm_task.done():
            yield 'data: {"type":"keepalive"}\n\n'
            try:
                await asyncio.wait_for(asyncio.shield(llm_task), timeout=3)
            except asyncio.TimeoutError:
                continue
            except Exception:
                # The task failed; its error is reported when result() is
                # read below.
                break

        try:
            reply, backend_used = llm_task.result()
            turns.append({"role": "assistant", "content": reply})
            save_session(sid, turns)
            log_turn(sid, req.message, reply)

            wanted_backend = req.model or settings.primary_backend
            final_event = {
                "type": "response",
                "response": reply,
                "session_id": sid,
                "backend": backend_used,
                "fallback_used": backend_used != wanted_backend,
            }
            yield f"data: {json.dumps(final_event)}\n\n"
        except Exception as exc:
            failure = {'type': 'error', 'message': str(exc)}
            yield f"data: {json.dumps(failure)}\n\n"
    finally:
        # If the generator is torn down early (e.g. client disconnect or
        # server shutdown) the LLM task must not be left running. Cancel it
        # and wait for it to finish unwinding. The original author notes that
        # this is meant to let the backend call (_gemini) kill its process
        # group — not verifiable from this file.
        if not llm_task.done():
            llm_task.cancel()
            try:
                await llm_task
            except (asyncio.CancelledError, Exception):
                pass
@router.post("/chat")
async def chat(req: ChatRequest) -> StreamingResponse:
    """Run one chat turn and stream its events back as text/event-stream."""
    # Headers asking caches and buffering proxies to pass events through as
    # they are produced.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    event_source = _stream_chat(req)
    return StreamingResponse(
        event_source,
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.get("/backend")
async def get_backend() -> dict:
    """Report the configured primary backend and the other one as its fallback."""
    primary = settings.primary_backend
    # Anything other than "claude" as primary yields "claude" as the fallback.
    fallback = "claude" if primary != "claude" else "gemini"
    return {"primary": primary, "fallback": fallback}
@router.post("/backend")
async def set_backend(req: BackendRequest) -> dict:
    """Change the primary backend in settings and report the resulting pair.

    Raises:
        HTTPException: 400 when ``primary`` is neither "claude" nor "gemini".
    """
    # Maps each accepted backend to the one reported as its fallback.
    counterpart = {"claude": "gemini", "gemini": "claude"}
    if req.primary not in counterpart:
        raise HTTPException(status_code=400, detail="primary must be 'claude' or 'gemini'")

    settings.primary_backend = req.primary
    return {
        "primary": settings.primary_backend,
        "fallback": counterpart[req.primary],
    }
@router.get("/history/{session_id}")
async def get_history(session_id: str) -> dict:
    """Return the stored message list for the given session."""
    stored_messages = load_session(session_id)
    return {"session_id": session_id, "messages": stored_messages}
@router.get("/sessions")
async def list_sessions() -> dict:
    """Return whatever the session store lists, under the "sessions" key."""
    known_sessions = list_all()
    return {"sessions": known_sessions}
@router.put("/history/{session_id}")
async def replace_history(session_id: str, req: HistoryUpdate) -> dict:
    """Overwrite a session's entire message list (used by the edit/delete UI)."""
    # The submitted list is stored as-is; no per-message validation happens here.
    replacement = req.messages
    save_session(session_id, replacement)
    return {"ok": True, "session_id": session_id}
@router.get("/events")
async def sse_events() -> StreamingResponse:
    """Server-sent events stream — pushes real-time Talk activity to the browser.

    Subscribes to ``event_bus`` and forwards every item taken from the returned
    queue as a ``data: <json>`` frame. When nothing arrives within 20 seconds a
    keepalive frame is sent instead. The subscription is released when the
    stream ends for any reason.
    """

    async def stream():
        q = event_bus.subscribe()
        try:
            while True:
                # Only the wait itself is guarded, so a timeout is the only
                # thing treated as "no event yet".
                try:
                    event = await asyncio.wait_for(q.get(), timeout=20)
                except asyncio.TimeoutError:
                    yield 'data: {"type":"keepalive"}\n\n'
                else:
                    yield f"data: {json.dumps(event)}\n\n"
        finally:
            # Runs on generator close and on task cancellation alike.
            # GeneratorExit / CancelledError are deliberately NOT caught:
            # swallowing a cancellation hides it from the task that requested
            # it, and the cleanup here does not need to intercept it.
            event_bus.unsubscribe(q)

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@router.post("/note")
async def add_note(req: NoteRequest) -> dict:
    """Append a "[NOTE] "-prefixed user message to a session so the LLM sees it next turn."""
    sid = req.session_id
    messages = load_session(sid)
    note_entry = {"role": "user", "content": f"[NOTE] {req.note}"}
    messages.append(note_entry)
    save_session(sid, messages)
    return {"ok": True, "session_id": sid}