diff --git a/cortex/config.py b/cortex/config.py index a113144..fd67cc7 100644 --- a/cortex/config.py +++ b/cortex/config.py @@ -65,6 +65,20 @@ class Settings(BaseSettings): distill_backend_mid: str = "" distill_backend_long: str = "" + # Model registry: default backend type per role when user registry has no entry. + # Values: "claude_cli" | "gemini_cli" | "gemini_api" (builtin IDs) + # Override in .env: ROLE_CHAT=claude_cli ROLE_DISTILL=gemini_api etc. + role_chat: str = "claude_cli" + role_orchestrator: str = "gemini_api" + role_distill: str = "claude_cli" + role_coder: str = "claude_cli" + role_research: str = "gemini_api" + + # Comma-separated list of standard roles shown in the model settings UI. + # Add custom roles here to extend the UI without code changes. + # Example: DEFINED_ROLES=chat,orchestrator,distill,coder,research,medical + defined_roles: str = "chat,orchestrator,distill,coder,research" + # Memory tier token budgets — soft caps used during distillation # Override in .env: MEMORY_BUDGET_LONG=4000 etc. memory_budget_long: int = 2000 @@ -90,6 +104,14 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") + def get_defined_roles(self) -> list[str]: + """Return the ordered list of standard roles from the defined_roles setting.""" + return [r.strip() for r in self.defined_roles.split(",") if r.strip()] + + def get_role_default(self, role: str) -> str: + """Return the .env default backend type for a role (e.g. 'claude_cli').""" + return getattr(self, f"role_{role.replace('-', '_')}", "claude_cli") + def home_root(self) -> Path: """Resolve home_dir relative to this file's location if not absolute.""" if self.home_dir.is_absolute(): diff --git a/cortex/cron_runner.py b/cortex/cron_runner.py index b051b11..7af7107 100644 --- a/cortex/cron_runner.py +++ b/cortex/cron_runner.py @@ -181,6 +181,7 @@ async def run_job(job: dict) -> None: response_text, backend = await complete( system_prompt=system_prompt, messages=[{"role": "user", "content": payload}], + role="chat", ) await notify(username, response_text, channel=channel) logger.info("cron [brief] sent via %s: %s", backend, label) diff --git a/cortex/llm_client.py b/cortex/llm_client.py index b9d0ad8..f84660e 100644 --- a/cortex/llm_client.py +++ b/cortex/llm_client.py @@ -31,7 +31,16 @@ async def cleanup() -> None: _active_pgroups.clear() -_BACKENDS = ("claude", "gemini", "local") +# Map from registry model type → dispatch function key +_TYPE_TO_BACKEND = { + "claude_cli": "claude", + "gemini_cli": "gemini", + "gemini_api": "gemini", # gemini_api falls back to CLI in this context + "local_openai": "local", +} + +# Explicit UI toggle values (kept for backward compat) +_EXPLICIT_BACKENDS = ("claude", "gemini", "local") _FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"} @@ -39,18 +48,42 @@ async def complete( system_prompt: str, messages: list[dict], model: str | None = None, + role: str = "chat", max_tokens: int = 2048, ) -> tuple[str, str]: - """Returns (response_text, actual_backend_used).""" - if model in _BACKENDS: + """ + Returns (response_text, actual_backend_used). + + model: explicit backend override ("claude" | "gemini" | "local") from UI toggle. + None = resolve via model registry for the given role. + role: registry role used when model is None (default: "chat"). + """ + import model_registry as _reg + from persona import _user + + username = _user.get() + resolved_cfg: dict | None = None + + if model in _EXPLICIT_BACKENDS: + # User explicitly selected a backend in the UI + if model == "local": + resolved_cfg = _reg.get_best_local_model(username, role) + if not resolved_cfg: + raise RuntimeError("No local model configured — add one at /settings/models") primary = model else: - primary = settings.primary_backend + # Role-based routing via model registry + resolved = _reg.get_model_for_role(username, role) + if resolved: + resolved_cfg = resolved + primary = _TYPE_TO_BACKEND.get(resolved["type"], "claude") + else: + primary = settings.primary_backend fallback = _FALLBACK.get(primary, "claude") try: - response = await _dispatch(primary, system_prompt, messages, model) + response = await _dispatch(primary, system_prompt, messages, resolved_cfg) return response, primary except Exception as e: err_str = str(e) @@ -65,13 +98,13 @@ async def _dispatch( backend: str, system_prompt: str, messages: list[dict], - model: str | None, + model_cfg: dict | None, ) -> str: if backend == "gemini": return await _gemini(system_prompt, messages) if backend == "local": - return await _local(system_prompt, messages) - return await _claude(system_prompt, messages, model) + return await _local(system_prompt, messages, model_cfg) + return await _claude(system_prompt, messages, model_cfg) def _fresh_claude_token() -> str | None: @@ -91,14 +124,16 @@ def _fresh_claude_token() -> str | None: return None -async def _claude(system_prompt: str, messages: list[dict], model: str | None) -> str: +async def _claude(system_prompt: str, messages: list[dict], model_cfg: dict | None) -> str: + model_name = (model_cfg or {}).get("model_name") if model_cfg else None cmd = [ "claude", "--print", "--no-session-persistence", "--output-format", "text", ] - if model and model not in ("claude", "gemini"): - cmd.extend(["--model", model]) + # Only pass --model if it's a real model name (not a backend type string) + if model_name and model_name not in ("claude", "gemini", "local", ""): + cmd.extend(["--model", model_name]) if system_prompt: cmd.extend(["--system-prompt", system_prompt]) cmd.append(_build_conversation(messages)) @@ -114,19 +149,22 @@ async def _claude(system_prompt: str, messages: list[dict], model: str | None) - return await _run(cmd, timeout=settings.timeout_claude, env=env) -async def _local(system_prompt: str, messages: list[dict]) -> str: +async def _local(system_prompt: str, messages: list[dict], model_cfg: dict | None = None) -> str: """OpenAI-compatible backend — Open WebUI / Ollama. - Per-user config (home/{user}/local_llm.json) takes precedence over - the server-level .env defaults. + model_cfg is pre-resolved by complete() via model_registry. + Falls back to registry lookup if not provided. """ import httpx - from persona import _user - from user_settings import get_active_local_model - cfg = get_active_local_model(_user.get()) + cfg = model_cfg if not cfg: - raise RuntimeError("No local model configured — add one at /settings/local") + # Fallback: resolve directly from registry + import model_registry as _reg + from persona import _user + cfg = _reg.get_best_local_model(_user.get()) + if not cfg: + raise RuntimeError("No local model configured — add one at /settings/models") api_url = cfg["api_url"] api_key = cfg["api_key"] diff --git a/cortex/memory_distiller.py b/cortex/memory_distiller.py index 0c2b555..a68fca0 100644 --- a/cortex/memory_distiller.py +++ b/cortex/memory_distiller.py @@ -92,7 +92,6 @@ async def distill_mid(username: str | None = None, persona: str | None = None) - if not short_content.strip() or "Not yet populated" in short_content: return {"error": "MEMORY_SHORT.md is empty — run distill/short first"} - backend_override = settings.distill_backend_mid or None budget_tokens = settings.memory_budget_mid system_prompt = ( f"You are {settings.agent_name}'s memory distillation system. " @@ -107,7 +106,7 @@ async def distill_mid(username: str | None = None, persona: str | None = None) - response_text, backend = await complete( system_prompt=system_prompt, messages=[{"role": "user", "content": short_content}], - model=backend_override, + role="distill", ) now = datetime.now().strftime("%Y-%m-%d %H:%M") @@ -146,7 +145,6 @@ async def distill_long(username: str | None = None, persona: str | None = None) if not mid_content.strip() or "Not yet populated" in mid_content: return {"error": "MEMORY_MID.md is empty — run distill/mid first"} - backend_override = settings.distill_backend_long or None budget_tokens = settings.memory_budget_long system_prompt = ( f"You are {settings.agent_name}'s long-term memory curator. " @@ -165,7 +163,7 @@ async def distill_long(username: str | None = None, persona: str | None = None) response_text, backend = await complete( system_prompt=system_prompt, messages=[{"role": "user", "content": user_content}], - model=backend_override, + role="distill", ) # Ensure the file has the right header if the LLM dropped it diff --git a/cortex/model_registry.py b/cortex/model_registry.py new file mode 100644 index 0000000..16eca45 --- /dev/null +++ b/cortex/model_registry.py @@ -0,0 +1,437 @@ +""" +Per-user unified model registry. + +Stored in: home/{user}/model_registry.json + +Schema: + { + "version": 1, + "hosts": [{"id", "label", "api_url", "api_key"}, ...], + "models": [ + { + "id": str, # unique within this registry + "type": str, # "local_openai" | "claude_cli" | "gemini_cli" | "gemini_api" + "label": str, # human-readable display name + "model_name": str, # model identifier sent to the API + "host_id": str | null, # only for local_openai — references hosts[].id + "context_k": int, # context window in thousands of tokens (informational) + "tags": [str], # user-defined capability tags + }, + ], + "roles": { + "": { + "primary": "" | null, + "backup_1": "" | null, + "backup_2": "" | null, + "backup_3": "" | null, + "backup_4": "" | null, + }, + }, + } + +Built-in model IDs (always resolvable, no registry entry required): + "claude_cli" — Claude CLI subprocess (~/.claude/.credentials.json) + "gemini_cli" — Gemini CLI subprocess + "gemini_api" — Gemini API (google-genai SDK; used by orchestrator engine, not llm_client) + +Standard roles are defined by settings.defined_roles (default: chat,orchestrator,distill,coder,research). +Additional custom roles can be added freely to roles{}. + +Resolution for get_model_for_role(username, role): + 1. User registry: roles[role].primary → backup_1 → backup_2 → backup_3 → backup_4 + 2. .env default: ROLE_= (e.g. ROLE_CHAT=claude_cli) + 3. Hardcoded last-resort defaults per role +""" + +import json +import logging +import secrets +from pathlib import Path + +from config import settings + +logger = logging.getLogger(__name__) + +# ── Built-in model definitions ──────────────────────────────────────────────── +# These IDs are always resolvable without a registry entry. + +def _builtins() -> dict[str, dict]: + """Return built-in model definitions (lazy so settings are resolved at call time).""" + return { + "claude_cli": { + "id": "claude_cli", + "type": "claude_cli", + "label": f"Claude (CLI) — {settings.default_model}", + "model_name": settings.default_model, + "context_k": 200, + "tags": ["chat", "persona", "creative"], + }, + "gemini_cli": { + "id": "gemini_cli", + "type": "gemini_cli", + "label": "Gemini (CLI)", + "model_name": "", + "context_k": 1000, + "tags": ["chat", "research", "long_context"], + }, + "gemini_api": { + "id": "gemini_api", + "type": "gemini_api", + "label": f"Gemini API — {settings.orchestrator_model}", + "model_name": settings.orchestrator_model, + "context_k": 1000, + "tags": ["orchestrator", "research", "long_context", "tools"], + }, + } + + +# Hardcoded last-resort defaults per role (used only if .env is also unset) +_ROLE_LAST_RESORT: dict[str, str] = { + "chat": "claude_cli", + "orchestrator": "gemini_api", + "distill": "claude_cli", + "coder": "claude_cli", + "research": "gemini_api", +} + +PRIORITY_KEYS = ["primary", "backup_1", "backup_2", "backup_3", "backup_4"] + + +# ── Storage ─────────────────────────────────────────────────────────────────── + +def _registry_path(username: str) -> Path: + return settings.home_root() / username / "model_registry.json" + + +def _local_llm_path(username: str) -> Path: + return settings.home_root() / username / "local_llm.json" + + +def _empty() -> dict: + return {"version": 1, "hosts": [], "models": [], "roles": {}} + + +def _load(username: str) -> dict: + path = _registry_path(username) + if path.exists(): + try: + data = json.loads(path.read_text()) + if isinstance(data, dict) and "version" in data: + return data + except (json.JSONDecodeError, OSError): + logger.warning("model_registry.json for %s is unreadable — starting fresh", username) + return _empty() + + # No registry yet — try migrating from local_llm.json + legacy = _local_llm_path(username) + if legacy.exists(): + data = _migrate_from_local_llm(username, legacy) + _save(username, data) + logger.info("Migrated local_llm.json → model_registry.json for %s", username) + return data + + return _empty() + + +def _save(username: str, data: dict) -> None: + _registry_path(username).write_text(json.dumps(data, indent=2)) + + +# ── Migration ───────────────────────────────────────────────────────────────── + +def _migrate_from_local_llm(username: str, path: Path) -> dict: + """Convert local_llm.json (hosts/models/active_model_id) → model_registry format.""" + try: + old = json.loads(path.read_text()) + except Exception: + return _empty() + + data = _empty() + + # Handle v0 flat format + if "hosts" not in old: + api_url = old.get("api_url") or settings.local_api_url + api_key = old.get("api_key") or settings.local_api_key + model_name = old.get("model") or settings.local_model + if not api_url: + return data + host_id = secrets.token_hex(4) + old = { + "hosts": [{"id": host_id, "label": "Local Model Server", "api_url": api_url, "api_key": api_key}], + "models": [{"id": secrets.token_hex(4), "host_id": host_id, "label": model_name, "model_name": model_name}] if model_name else [], + "active_model_id": None, + } + if old["models"]: + old["active_model_id"] = old["models"][0]["id"] + + data["hosts"] = old.get("hosts", []) + + for m in old.get("models", []): + data["models"].append({ + "id": m["id"], + "type": "local_openai", + "label": m.get("label") or m.get("model_name", ""), + "model_name": m.get("model_name", ""), + "host_id": m.get("host_id"), + "context_k": 0, + "tags": [], + }) + + # Build initial role assignments + active_id = old.get("active_model_id") + distill_type = settings.distill_backend_mid or None + + roles: dict[str, dict] = {} + if active_id and any(m["id"] == active_id for m in data["models"]): + roles["chat"] = {"primary": active_id} + + if distill_type == "local" and active_id: + roles["distill"] = {"primary": active_id} + + data["roles"] = roles + return data + + +# ── Model resolution ────────────────────────────────────────────────────────── + +def _resolve_model(registry: dict, model_id: str) -> dict | None: + """Resolve a model_id to its full config dict, or None if not found.""" + builtins = _builtins() + + # Built-in IDs take priority over user-defined entries with the same ID + if model_id in builtins: + return dict(builtins[model_id]) + + model = next((m for m in registry.get("models", []) if m["id"] == model_id), None) + if not model: + return None + + if model.get("type") == "local_openai": + host_id = model.get("host_id") + host = next((h for h in registry.get("hosts", []) if h["id"] == host_id), None) + if not host: + logger.warning("model %s references missing host_id %s", model_id, host_id) + return None + return {**model, "api_url": host.get("api_url", ""), "api_key": host.get("api_key", "")} + + return dict(model) + + +def get_model_for_role(username: str, role: str) -> dict | None: + """ + Return the resolved model config for the given role. + + Resolution order: + 1. User registry: roles[role].primary → backup_1 → ... → backup_4 + 2. .env: ROLE_ = builtin model ID + 3. Hardcoded last-resort default per role + 4. claude_cli (absolute fallback) + """ + registry = _load(username) + role_cfg = registry.get("roles", {}).get(role, {}) + + for key in PRIORITY_KEYS: + model_id = role_cfg.get(key) + if not model_id: + continue + resolved = _resolve_model(registry, model_id) + if resolved: + return resolved + logger.debug("role %s.%s = %s but model not found", role, key, model_id) + + # .env default + env_type = settings.get_role_default(role) + builtins = _builtins() + if env_type and env_type in builtins: + return dict(builtins[env_type]) + + # Hardcoded last resort + fallback_id = _ROLE_LAST_RESORT.get(role, "claude_cli") + return dict(builtins.get(fallback_id, builtins["claude_cli"])) + + +def get_best_local_model(username: str, role: str = "chat") -> dict | None: + """ + Return the best available local_openai model for the given role. + Used when the user explicitly selects "local" backend in the UI. + Tries the role's priority chain first, then any configured local model. + """ + registry = _load(username) + role_cfg = registry.get("roles", {}).get(role, {}) + + for key in PRIORITY_KEYS: + model_id = role_cfg.get(key) + if not model_id: + continue + resolved = _resolve_model(registry, model_id) + if resolved and resolved.get("type") == "local_openai": + return resolved + + # Fall back to first configured local model + for model in registry.get("models", []): + if model.get("type") == "local_openai": + resolved = _resolve_model(registry, model["id"]) + if resolved: + return resolved + + return None + + +# ── Read API (for UI and callers) ───────────────────────────────────────────── + +def get_registry(username: str) -> dict: + """Return the full registry (with built-in models injected for display).""" + return _load(username) + + +def get_all_models(username: str) -> list[dict]: + """Return all user-defined models (resolved — hosts merged in).""" + registry = _load(username) + out = [] + for m in registry.get("models", []): + resolved = _resolve_model(registry, m["id"]) + if resolved: + out.append(resolved) + return out + + +def get_defined_roles(username: str) -> dict[str, dict]: + """Return the roles section of the registry, filling gaps with empty dicts.""" + registry = _load(username) + roles = registry.get("roles", {}) + result = {} + for role in settings.get_defined_roles(): + result[role] = roles.get(role, {}) + return result + + +# ── Write API (CRUD) ────────────────────────────────────────────────────────── + +def save_host(username: str, host_id: str | None, + label: str, api_url: str, api_key: str) -> str: + """Create or update a host. Returns the host ID.""" + data = _load(username) + + if host_id: + for h in data["hosts"]: + if h["id"] == host_id: + h["label"] = label.strip() + h["api_url"] = api_url.strip() + if api_key.strip(): + h["api_key"] = api_key.strip() + _save(username, data) + return host_id + host_id = None # not found — create new + + host_id = secrets.token_hex(4) + data["hosts"].append({ + "id": host_id, + "label": label.strip(), + "api_url": api_url.strip(), + "api_key": api_key.strip(), + }) + _save(username, data) + return host_id + + +def remove_host(username: str, host_id: str) -> bool: + """Remove a host and all models that reference it. Returns True if found.""" + data = _load(username) + before = len(data["hosts"]) + data["hosts"] = [h for h in data["hosts"] if h["id"] != host_id] + data["models"] = [m for m in data["models"] if m.get("host_id") != host_id] + # Clear any role assignments that pointed to removed models + removed_ids = {m["id"] for m in data["models"] if m.get("host_id") == host_id} + for role_cfg in data.get("roles", {}).values(): + for key in PRIORITY_KEYS: + if role_cfg.get(key) in removed_ids: + role_cfg[key] = None + _save(username, data) + return len(data["hosts"]) < before + + +def save_model(username: str, model_id: str | None, host_id: str, + label: str, model_name: str, context_k: int = 0, + tags: list[str] | None = None) -> str: + """Create or update a model entry. Returns the model ID.""" + data = _load(username) + tags = tags or [] + + if model_id: + for m in data["models"]: + if m["id"] == model_id: + m["host_id"] = host_id + m["label"] = label.strip() or model_name.strip() + m["model_name"] = model_name.strip() + m["context_k"] = context_k + m["tags"] = tags + _save(username, data) + return model_id + model_id = None + + model_id = secrets.token_hex(4) + data["models"].append({ + "id": model_id, + "type": "local_openai", + "label": label.strip() or model_name.strip(), + "model_name": model_name.strip(), + "host_id": host_id, + "context_k": context_k, + "tags": tags, + }) + _save(username, data) + return model_id + + +def remove_model(username: str, model_id: str) -> bool: + """Remove a model and clear any role assignments pointing to it.""" + data = _load(username) + before = len(data["models"]) + data["models"] = [m for m in data["models"] if m["id"] != model_id] + + for role_cfg in data.get("roles", {}).values(): + for key in PRIORITY_KEYS: + if role_cfg.get(key) == model_id: + role_cfg[key] = None + + _save(username, data) + return len(data["models"]) < before + + +def set_role(username: str, role: str, priority: str, model_id: str | None) -> bool: + """ + Assign a model to a role priority slot. + + priority must be one of: primary, backup_1, backup_2, backup_3, backup_4 + model_id None clears the slot. + model_id "claude_cli" / "gemini_cli" / "gemini_api" are valid built-in IDs. + Returns False if model_id is set but not found. + """ + if priority not in PRIORITY_KEYS: + return False + + data = _load(username) + + if model_id and model_id not in _builtins(): + if not any(m["id"] == model_id for m in data["models"]): + return False + + roles = data.setdefault("roles", {}) + if role not in roles: + roles[role] = {} + roles[role][priority] = model_id or None + + _save(username, data) + return True + + +def fetch_models_from_host(api_url: str, api_key: str) -> list[str]: + """Synchronously fetch the model list from an OpenAI-compatible host.""" + import httpx + url = api_url.rstrip("/") + "/api/models" + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + resp = httpx.get(url, headers=headers, timeout=10) + resp.raise_for_status() + data = resp.json() + models = data.get("data", []) + return sorted(m.get("id", m.get("name", "")) for m in models if m.get("id") or m.get("name")) diff --git a/cortex/routers/auth.py b/cortex/routers/auth.py index fb1bfd6..3167edd 100644 --- a/cortex/routers/auth.py +++ b/cortex/routers/auth.py @@ -72,21 +72,33 @@ def _gemini_status() -> dict: return {"ok": False, "error": str(e), "warning": True, "authenticated": False} -async def _local_status() -> dict: - if not settings.local_api_url: +async def _local_status(username: str = "scott") -> dict: + """Check reachability of the user's configured local model host.""" + import model_registry + cfg = model_registry.get_best_local_model(username) + if not cfg: + return {"configured": False} + api_url = cfg.get("api_url", "") + if not api_url: return {"configured": False} try: import httpx - url = settings.local_api_url.rstrip("/") + "/api/models" + url = api_url.rstrip("/") + "/api/models" headers = {} - if settings.local_api_key: - headers["Authorization"] = f"Bearer {settings.local_api_key}" + api_key = cfg.get("api_key", "") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" async with httpx.AsyncClient(timeout=5) as client: resp = await client.get(url, headers=headers) reachable = resp.status_code < 400 - return {"configured": True, "reachable": reachable, "model": settings.local_model} + return { + "configured": True, + "reachable": reachable, + "model": cfg.get("model_name", ""), + "label": cfg.get("label", ""), + } except Exception as e: - return {"configured": True, "reachable": False, "error": str(e), "model": settings.local_model} + return {"configured": True, "reachable": False, "error": str(e), "model": cfg.get("model_name", "")} @router.get("/status") diff --git a/cortex/routers/chat.py b/cortex/routers/chat.py index d78c3c3..14cc3fb 100644 --- a/cortex/routers/chat.py +++ b/cortex/routers/chat.py @@ -11,7 +11,7 @@ from session_store import load as load_session, save as save_session, list_all, from config import settings from persona import set_context, validate as validate_persona from auth_utils import COOKIE_NAME, decode_token -import user_settings +import model_registry import event_bus @@ -138,15 +138,15 @@ _BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"} def _local_model_info(request: Request) -> dict | None: - """Return active local model {label, model_name} for the session user, or None.""" + """Return the best local model {label, model_name} for the session user, or None.""" try: token = request.cookies.get(COOKIE_NAME) username = decode_token(token) if token else None if not username: return None - cfg = user_settings.get_active_local_model(username) + cfg = model_registry.get_best_local_model(username, "chat") if cfg: - return {"label": cfg["label"], "model_name": cfg["model_name"]} + return {"label": cfg.get("label", ""), "model_name": cfg.get("model_name", "")} except (jwt.InvalidTokenError, Exception): pass return None