"""
Manual memory distillation endpoints.

    GET  /distill/status  — auto-distill config, scheduler job status, and which
                            personas are currently distilling
    POST /distill/short   — roll session logs → MEMORY_SHORT.md (no LLM)
    POST /distill/mid     — summarize short → MEMORY_MID.md (LLM)
    POST /distill/long    — integrate mid → MEMORY_LONG.md (LLM)
    POST /distill/all     — run all three in sequence
    POST /distill/rebuild — wipe mid + long, then run all three from scratch

All POST endpoints require ?user=&persona= query params.

Concurrency: one distillation at a time per persona. A second request while
one is running returns 409 immediately — no silent queuing.

Rate limiting: each POST endpoint also has a per-persona cooldown (see
_COOLDOWNS); calling it again before the cooldown has elapsed returns 429.
"""
import asyncio
from datetime import datetime, timedelta

from fastapi import APIRouter, HTTPException, Query

from memory_distiller import distill_short, distill_mid, distill_long
from persona import validate as validate_persona, set_context, persona_path as _persona_path
import scheduler

router = APIRouter(prefix="/distill")

# Per-persona asyncio lock, created lazily. Key: (user, persona)
_LOCKS: dict[tuple[str, str], asyncio.Lock] = {}
# Human-readable label of the step currently holding a persona's lock, used in
# the 409 message. Key: (user, persona) → e.g. "short distill"
_LOCKS_META: dict[tuple[str, str], str] = {}

# Minimum time between successive runs of each endpoint, per persona.
# Prevents accidental rapid-fire runs and token waste.
_COOLDOWNS: dict[str, timedelta] = {
    "short": timedelta(minutes=1),
    "mid": timedelta(minutes=30),
    "long": timedelta(hours=6),
    "all": timedelta(hours=1),
    "rebuild": timedelta(hours=6),
}
# When each endpoint last completed a run that counts toward its cooldown.
# Key: (user, persona, endpoint)
_LAST_RUN: dict[tuple[str, str, str], datetime] = {}


def _get_lock(user: str, persona: str) -> asyncio.Lock:
    """Return the lock for (user, persona), creating it on first use."""
    key = (user, persona)
    if key not in _LOCKS:
        _LOCKS[key] = asyncio.Lock()
    return _LOCKS[key]


def _resolve(user: str, persona: str) -> tuple[str, str]:
    """Validate the persona, make it the active context, and return (user, persona).

    Any exception from ``validate_persona`` is turned into a 404. The original
    exception is chained as ``__cause__`` so it still appears in server logs.
    """
    try:
        u, p = validate_persona(user, persona)
    except Exception as exc:
        raise HTTPException(
            status_code=404,
            detail=f"Persona not found: {user}/{persona}",
        ) from exc
    set_context(u, p)
    return u, p


def _check_lock(user: str, persona: str) -> asyncio.Lock:
    """Return the lock if free, raise 409 if already held.

    The caller must enter the returned lock without awaiting anything first.
    This keeps the check and the acquisition atomic with respect to the event loop.
    """
    lock = _get_lock(user, persona)
    if lock.locked():
        step = _LOCKS_META.get((user, persona), "distillation")
        raise HTTPException(
            status_code=409,
            detail=f"A {step} is already running for {persona} — please wait for it to finish.",
        )
    return lock


def _format_wait(remaining: timedelta) -> str:
    """Render a positive wait compactly: ``"2h 5m"``, ``"3m 20s"`` or ``"45s"``.

    Seconds are rounded *up* so a sub-second remainder never renders as "0s".
    """
    import math

    total = max(1, math.ceil(remaining.total_seconds()))
    hours, rest = divmod(total, 3600)
    mins, secs = divmod(rest, 60)
    if hours:
        return f"{hours}h {mins}m"
    if mins:
        return f"{mins}m {secs}s"
    return f"{secs}s"


def _check_cooldown(user: str, persona: str, endpoint: str) -> None:
    """Raise 429 if the endpoint was run too recently for this persona.

    Endpoints without an entry in _COOLDOWNS are never rate limited.
    """
    cooldown = _COOLDOWNS.get(endpoint)
    if not cooldown:
        return
    last = _LAST_RUN.get((user, persona, endpoint))
    if last is None:
        return
    remaining = cooldown - (datetime.now() - last)
    if remaining > timedelta(0):
        raise HTTPException(
            status_code=429,
            detail=f"{endpoint} was just run — please wait {_format_wait(remaining)} before running again.",
        )


def _record_run(user: str, persona: str, endpoint: str) -> None:
    """Start the cooldown window for ``endpoint`` on this persona from now."""
    _LAST_RUN[(user, persona, endpoint)] = datetime.now()


async def _distill_chain(user: str, persona: str) -> dict:
    """Run short → mid → long for a persona and collect each step's result.

    The caller must already hold the persona's lock.

    Returns ``{"ok", "short", "mid", "long"}``. When mid reports an error, long
    is skipped and the ``"long"`` key is absent. An error in the short step is
    reported in ``"short"`` but does not stop mid/long from running.
    """
    short_result = distill_short(user, persona)
    mid_result = await distill_mid(user, persona)
    if "error" in mid_result:
        return {"ok": False, "short": short_result, "mid": mid_result}
    long_result = await distill_long(user, persona)
    return {
        "ok": "error" not in long_result,
        "short": short_result,
        "mid": mid_result,
        "long": long_result,
    }


def _clear_for_rebuild(user: str, persona: str) -> None:
    """Back up (via memory_distiller._rotate_backup) then blank MEMORY_MID/LONG.

    Each file is overwritten with a placeholder heading and a timestamp. A
    missing file is simply created; there is nothing to back up in that case.
    """
    # Imported lazily, as before; these are private helpers of memory_distiller.
    from memory_distiller import _rotate_backup

    persona_dir = _persona_path(user, persona)
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M")
    for name in ("MEMORY_MID.md", "MEMORY_LONG.md"):
        path = persona_dir / name
        if path.exists():
            _rotate_backup(path)
        # Explicit UTF-8: the placeholder contains a non-ASCII em dash.
        path.write_text(
            f"# {name}\n\n*Cleared for rebuild — {stamp}.*\n",
            encoding="utf-8",
        )


@router.get("/status")
async def distill_status() -> dict:
    """Report auto-distill settings, scheduler job status and active personas."""
    from config import settings

    # Personas whose lock is currently held, i.e. mid-distillation.
    active = [f"{u}/{p}" for (u, p), lock in _LOCKS.items() if lock.locked()]
    return {
        "enabled": settings.auto_distill,
        "jobs": scheduler.status(),
        "active": active,
        "config": {
            "short": settings.auto_distill_short,
            "mid": settings.auto_distill_mid,
            "long": settings.auto_distill_long,
        },
    }


@router.post("/short")
async def do_distill_short(
    user: str = Query(...),
    persona: str = Query(...),
) -> dict:
    """Roll session logs into MEMORY_SHORT.md (no LLM).

    The cooldown only starts when the step returns no ``"error"`` key, the same
    rule /mid and /long use.
    """
    u, p = _resolve(user, persona)
    _check_cooldown(u, p, "short")
    lock = _check_lock(u, p)
    async with lock:
        _LOCKS_META[(u, p)] = "short distill"
        try:
            # NOTE(review): distill_short is called synchronously, so it runs on
            # the event-loop thread and blocks other requests while it works.
            result = distill_short(u, p)
            ok = "error" not in result
            if ok:
                _record_run(u, p, "short")
            return {"ok": ok, **result}
        finally:
            _LOCKS_META.pop((u, p), None)


@router.post("/mid")
async def do_distill_mid(
    user: str = Query(...),
    persona: str = Query(...),
) -> dict:
    """Summarize MEMORY_SHORT.md into MEMORY_MID.md (LLM)."""
    u, p = _resolve(user, persona)
    _check_cooldown(u, p, "mid")
    lock = _check_lock(u, p)
    async with lock:
        _LOCKS_META[(u, p)] = "mid distill"
        try:
            result = await distill_mid(u, p)
            ok = "error" not in result
            if ok:
                _record_run(u, p, "mid")
            return {"ok": ok, **result}
        finally:
            _LOCKS_META.pop((u, p), None)


@router.post("/long")
async def do_distill_long(
    user: str = Query(...),
    persona: str = Query(...),
) -> dict:
    """Integrate MEMORY_MID.md into MEMORY_LONG.md (LLM)."""
    u, p = _resolve(user, persona)
    _check_cooldown(u, p, "long")
    lock = _check_lock(u, p)
    async with lock:
        _LOCKS_META[(u, p)] = "long distill"
        try:
            result = await distill_long(u, p)
            ok = "error" not in result
            if ok:
                _record_run(u, p, "long")
            return {"ok": ok, **result}
        finally:
            _LOCKS_META.pop((u, p), None)


@router.post("/all")
async def do_distill_all(
    user: str = Query(...),
    persona: str = Query(...),
) -> dict:
    """Run short → mid → long in sequence under a single lock hold.

    The "all" cooldown starts only if the final step succeeds. The individual
    step cooldowns are not touched.
    """
    u, p = _resolve(user, persona)
    _check_cooldown(u, p, "all")
    lock = _check_lock(u, p)
    async with lock:
        _LOCKS_META[(u, p)] = "full distill"
        try:
            result = await _distill_chain(u, p)
            if result["ok"]:
                _record_run(u, p, "all")
            return result
        finally:
            _LOCKS_META.pop((u, p), None)


@router.post("/rebuild")
async def do_distill_rebuild(
    user: str = Query(...),
    persona: str = Query(...),
) -> dict:
    """Wipe MEMORY_MID and MEMORY_LONG (with backups), then run short → mid → long.

    Use when memories have drifted, been corrupted, or you want a clean slate
    rebuilt purely from session logs. Hand-edited content will be replaced.

    If mid or long fails, the cleared placeholder files stay in place. The previous
    contents then survive only in the backups made before clearing.
    """
    u, p = _resolve(user, persona)
    _check_cooldown(u, p, "rebuild")
    lock = _check_lock(u, p)
    async with lock:
        _LOCKS_META[(u, p)] = "memory rebuild"
        try:
            _clear_for_rebuild(u, p)
            result = await _distill_chain(u, p)
            if result["ok"]:
                _record_run(u, p, "rebuild")
            return {**result, "rebuilt": True}
        finally:
            _LOCKS_META.pop((u, p), None)