feat: unified model registry with role-based routing
Introduces model_registry.py as the single source of truth for all LLM
backend configuration. Replaces scattered backend settings across user_settings,
config distill_backend_*, and the UI toggle.
model_registry.py:
- Per-user home/{user}/model_registry.json with version, hosts, models, roles
- Models have: type (local_openai|claude_cli|gemini_cli|gemini_api), label,
model_name, host_id, context_k (tokens), tags (capability labels)
- Roles map to priority chains: primary, backup_1..backup_4
- Built-in IDs (claude_cli, gemini_cli, gemini_api) always resolvable
- Auto-migrates existing local_llm.json on first access
- CRUD: save_host, remove_host, save_model, remove_model, set_role
- get_model_for_role(): registry → .env default → hardcoded fallback
config.py:
- role_chat/orchestrator/distill/coder/research .env defaults
- defined_roles: comma-separated standard role list (extensible)
- get_defined_roles() and get_role_default() helper methods
llm_client.complete():
- New role= parameter (default "chat") for registry-based routing
- model= still accepted for explicit UI toggle override
- _claude() and _local() accept model_cfg dict instead of raw string
- _local() uses pre-resolved config from registry
memory_distiller.py:
- distill_mid/long now use role="distill" (no more distill_backend_* .env vars needed)
cron_runner.py:
- brief jobs use role="chat"
routers/chat.py + auth.py:
- Use model_registry instead of user_settings for local model info
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -65,6 +65,20 @@ class Settings(BaseSettings):
|
||||
distill_backend_mid: str = ""
|
||||
distill_backend_long: str = ""
|
||||
|
||||
# Model registry: default backend type per role when user registry has no entry.
|
||||
# Values: "claude_cli" | "gemini_cli" | "gemini_api" (builtin IDs)
|
||||
# Override in .env: ROLE_CHAT=claude_cli ROLE_DISTILL=gemini_api etc.
|
||||
role_chat: str = "claude_cli"
|
||||
role_orchestrator: str = "gemini_api"
|
||||
role_distill: str = "claude_cli"
|
||||
role_coder: str = "claude_cli"
|
||||
role_research: str = "gemini_api"
|
||||
|
||||
# Comma-separated list of standard roles shown in the model settings UI.
|
||||
# Add custom roles here to extend the UI without code changes.
|
||||
# Example: DEFINED_ROLES=chat,orchestrator,distill,coder,research,medical
|
||||
defined_roles: str = "chat,orchestrator,distill,coder,research"
|
||||
|
||||
# Memory tier token budgets — soft caps used during distillation
|
||||
# Override in .env: MEMORY_BUDGET_LONG=4000 etc.
|
||||
memory_budget_long: int = 2000
|
||||
@@ -90,6 +104,14 @@ class Settings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
def get_defined_roles(self) -> list[str]:
|
||||
"""Return the ordered list of standard roles from the defined_roles setting."""
|
||||
return [r.strip() for r in self.defined_roles.split(",") if r.strip()]
|
||||
|
||||
def get_role_default(self, role: str) -> str:
|
||||
"""Return the .env default backend type for a role (e.g. 'claude_cli')."""
|
||||
return getattr(self, f"role_{role.replace('-', '_')}", "claude_cli")
|
||||
|
||||
def home_root(self) -> Path:
|
||||
"""Resolve home_dir relative to this file's location if not absolute."""
|
||||
if self.home_dir.is_absolute():
|
||||
|
||||
@@ -181,6 +181,7 @@ async def run_job(job: dict) -> None:
|
||||
response_text, backend = await complete(
|
||||
system_prompt=system_prompt,
|
||||
messages=[{"role": "user", "content": payload}],
|
||||
role="chat",
|
||||
)
|
||||
await notify(username, response_text, channel=channel)
|
||||
logger.info("cron [brief] sent via %s: %s", backend, label)
|
||||
|
||||
@@ -31,7 +31,16 @@ async def cleanup() -> None:
|
||||
_active_pgroups.clear()
|
||||
|
||||
|
||||
_BACKENDS = ("claude", "gemini", "local")
|
||||
# Map from registry model type → dispatch function key
|
||||
_TYPE_TO_BACKEND = {
|
||||
"claude_cli": "claude",
|
||||
"gemini_cli": "gemini",
|
||||
"gemini_api": "gemini", # gemini_api falls back to CLI in this context
|
||||
"local_openai": "local",
|
||||
}
|
||||
|
||||
# Explicit UI toggle values (kept for backward compat)
|
||||
_EXPLICIT_BACKENDS = ("claude", "gemini", "local")
|
||||
_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}
|
||||
|
||||
|
||||
@@ -39,18 +48,42 @@ async def complete(
|
||||
system_prompt: str,
|
||||
messages: list[dict],
|
||||
model: str | None = None,
|
||||
role: str = "chat",
|
||||
max_tokens: int = 2048,
|
||||
) -> tuple[str, str]:
|
||||
"""Returns (response_text, actual_backend_used)."""
|
||||
if model in _BACKENDS:
|
||||
"""
|
||||
Returns (response_text, actual_backend_used).
|
||||
|
||||
model: explicit backend override ("claude" | "gemini" | "local") from UI toggle.
|
||||
None = resolve via model registry for the given role.
|
||||
role: registry role used when model is None (default: "chat").
|
||||
"""
|
||||
import model_registry as _reg
|
||||
from persona import _user
|
||||
|
||||
username = _user.get()
|
||||
resolved_cfg: dict | None = None
|
||||
|
||||
if model in _EXPLICIT_BACKENDS:
|
||||
# User explicitly selected a backend in the UI
|
||||
if model == "local":
|
||||
resolved_cfg = _reg.get_best_local_model(username, role)
|
||||
if not resolved_cfg:
|
||||
raise RuntimeError("No local model configured — add one at /settings/models")
|
||||
primary = model
|
||||
else:
|
||||
primary = settings.primary_backend
|
||||
# Role-based routing via model registry
|
||||
resolved = _reg.get_model_for_role(username, role)
|
||||
if resolved:
|
||||
resolved_cfg = resolved
|
||||
primary = _TYPE_TO_BACKEND.get(resolved["type"], "claude")
|
||||
else:
|
||||
primary = settings.primary_backend
|
||||
|
||||
fallback = _FALLBACK.get(primary, "claude")
|
||||
|
||||
try:
|
||||
response = await _dispatch(primary, system_prompt, messages, model)
|
||||
response = await _dispatch(primary, system_prompt, messages, resolved_cfg)
|
||||
return response, primary
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
@@ -65,13 +98,13 @@ async def _dispatch(
|
||||
backend: str,
|
||||
system_prompt: str,
|
||||
messages: list[dict],
|
||||
model: str | None,
|
||||
model_cfg: dict | None,
|
||||
) -> str:
|
||||
if backend == "gemini":
|
||||
return await _gemini(system_prompt, messages)
|
||||
if backend == "local":
|
||||
return await _local(system_prompt, messages)
|
||||
return await _claude(system_prompt, messages, model)
|
||||
return await _local(system_prompt, messages, model_cfg)
|
||||
return await _claude(system_prompt, messages, model_cfg)
|
||||
|
||||
|
||||
def _fresh_claude_token() -> str | None:
|
||||
@@ -91,14 +124,16 @@ def _fresh_claude_token() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
async def _claude(system_prompt: str, messages: list[dict], model: str | None) -> str:
|
||||
async def _claude(system_prompt: str, messages: list[dict], model_cfg: dict | None) -> str:
|
||||
model_name = (model_cfg or {}).get("model_name") if model_cfg else None
|
||||
cmd = [
|
||||
"claude", "--print",
|
||||
"--no-session-persistence",
|
||||
"--output-format", "text",
|
||||
]
|
||||
if model and model not in ("claude", "gemini"):
|
||||
cmd.extend(["--model", model])
|
||||
# Only pass --model if it's a real model name (not a backend type string)
|
||||
if model_name and model_name not in ("claude", "gemini", "local", ""):
|
||||
cmd.extend(["--model", model_name])
|
||||
if system_prompt:
|
||||
cmd.extend(["--system-prompt", system_prompt])
|
||||
cmd.append(_build_conversation(messages))
|
||||
@@ -114,19 +149,22 @@ async def _claude(system_prompt: str, messages: list[dict], model: str | None) -
|
||||
return await _run(cmd, timeout=settings.timeout_claude, env=env)
|
||||
|
||||
|
||||
async def _local(system_prompt: str, messages: list[dict]) -> str:
|
||||
async def _local(system_prompt: str, messages: list[dict], model_cfg: dict | None = None) -> str:
|
||||
"""OpenAI-compatible backend — Open WebUI / Ollama.
|
||||
|
||||
Per-user config (home/{user}/local_llm.json) takes precedence over
|
||||
the server-level .env defaults.
|
||||
model_cfg is pre-resolved by complete() via model_registry.
|
||||
Falls back to registry lookup if not provided.
|
||||
"""
|
||||
import httpx
|
||||
from persona import _user
|
||||
from user_settings import get_active_local_model
|
||||
|
||||
cfg = get_active_local_model(_user.get())
|
||||
cfg = model_cfg
|
||||
if not cfg:
|
||||
raise RuntimeError("No local model configured — add one at /settings/local")
|
||||
# Fallback: resolve directly from registry
|
||||
import model_registry as _reg
|
||||
from persona import _user
|
||||
cfg = _reg.get_best_local_model(_user.get())
|
||||
if not cfg:
|
||||
raise RuntimeError("No local model configured — add one at /settings/models")
|
||||
|
||||
api_url = cfg["api_url"]
|
||||
api_key = cfg["api_key"]
|
||||
|
||||
@@ -92,7 +92,6 @@ async def distill_mid(username: str | None = None, persona: str | None = None) -
|
||||
if not short_content.strip() or "Not yet populated" in short_content:
|
||||
return {"error": "MEMORY_SHORT.md is empty — run distill/short first"}
|
||||
|
||||
backend_override = settings.distill_backend_mid or None
|
||||
budget_tokens = settings.memory_budget_mid
|
||||
system_prompt = (
|
||||
f"You are {settings.agent_name}'s memory distillation system. "
|
||||
@@ -107,7 +106,7 @@ async def distill_mid(username: str | None = None, persona: str | None = None) -
|
||||
response_text, backend = await complete(
|
||||
system_prompt=system_prompt,
|
||||
messages=[{"role": "user", "content": short_content}],
|
||||
model=backend_override,
|
||||
role="distill",
|
||||
)
|
||||
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
@@ -146,7 +145,6 @@ async def distill_long(username: str | None = None, persona: str | None = None)
|
||||
if not mid_content.strip() or "Not yet populated" in mid_content:
|
||||
return {"error": "MEMORY_MID.md is empty — run distill/mid first"}
|
||||
|
||||
backend_override = settings.distill_backend_long or None
|
||||
budget_tokens = settings.memory_budget_long
|
||||
system_prompt = (
|
||||
f"You are {settings.agent_name}'s long-term memory curator. "
|
||||
@@ -165,7 +163,7 @@ async def distill_long(username: str | None = None, persona: str | None = None)
|
||||
response_text, backend = await complete(
|
||||
system_prompt=system_prompt,
|
||||
messages=[{"role": "user", "content": user_content}],
|
||||
model=backend_override,
|
||||
role="distill",
|
||||
)
|
||||
|
||||
# Ensure the file has the right header if the LLM dropped it
|
||||
|
||||
437
cortex/model_registry.py
Normal file
437
cortex/model_registry.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""
|
||||
Per-user unified model registry.
|
||||
|
||||
Stored in: home/{user}/model_registry.json
|
||||
|
||||
Schema:
|
||||
{
|
||||
"version": 1,
|
||||
"hosts": [{"id", "label", "api_url", "api_key"}, ...],
|
||||
"models": [
|
||||
{
|
||||
"id": str, # unique within this registry
|
||||
"type": str, # "local_openai" | "claude_cli" | "gemini_cli" | "gemini_api"
|
||||
"label": str, # human-readable display name
|
||||
"model_name": str, # model identifier sent to the API
|
||||
"host_id": str | null, # only for local_openai — references hosts[].id
|
||||
"context_k": int, # context window in thousands of tokens (informational)
|
||||
"tags": [str], # user-defined capability tags
|
||||
},
|
||||
],
|
||||
"roles": {
|
||||
"<role>": {
|
||||
"primary": "<model_id>" | null,
|
||||
"backup_1": "<model_id>" | null,
|
||||
"backup_2": "<model_id>" | null,
|
||||
"backup_3": "<model_id>" | null,
|
||||
"backup_4": "<model_id>" | null,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Built-in model IDs (always resolvable, no registry entry required):
|
||||
"claude_cli" — Claude CLI subprocess (~/.claude/.credentials.json)
|
||||
"gemini_cli" — Gemini CLI subprocess
|
||||
"gemini_api" — Gemini API (google-genai SDK; used by orchestrator engine, not llm_client)
|
||||
|
||||
Standard roles are defined by settings.defined_roles (default: chat,orchestrator,distill,coder,research).
|
||||
Additional custom roles can be added freely to roles{}.
|
||||
|
||||
Resolution for get_model_for_role(username, role):
|
||||
1. User registry: roles[role].primary → backup_1 → backup_2 → backup_3 → backup_4
|
||||
2. .env default: ROLE_<ROLE>=<builtin_id> (e.g. ROLE_CHAT=claude_cli)
|
||||
3. Hardcoded last-resort defaults per role
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Built-in model definitions ────────────────────────────────────────────────
|
||||
# These IDs are always resolvable without a registry entry.
|
||||
|
||||
def _builtins() -> dict[str, dict]:
|
||||
"""Return built-in model definitions (lazy so settings are resolved at call time)."""
|
||||
return {
|
||||
"claude_cli": {
|
||||
"id": "claude_cli",
|
||||
"type": "claude_cli",
|
||||
"label": f"Claude (CLI) — {settings.default_model}",
|
||||
"model_name": settings.default_model,
|
||||
"context_k": 200,
|
||||
"tags": ["chat", "persona", "creative"],
|
||||
},
|
||||
"gemini_cli": {
|
||||
"id": "gemini_cli",
|
||||
"type": "gemini_cli",
|
||||
"label": "Gemini (CLI)",
|
||||
"model_name": "",
|
||||
"context_k": 1000,
|
||||
"tags": ["chat", "research", "long_context"],
|
||||
},
|
||||
"gemini_api": {
|
||||
"id": "gemini_api",
|
||||
"type": "gemini_api",
|
||||
"label": f"Gemini API — {settings.orchestrator_model}",
|
||||
"model_name": settings.orchestrator_model,
|
||||
"context_k": 1000,
|
||||
"tags": ["orchestrator", "research", "long_context", "tools"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Hardcoded last-resort defaults per role (used only if .env is also unset)
|
||||
_ROLE_LAST_RESORT: dict[str, str] = {
|
||||
"chat": "claude_cli",
|
||||
"orchestrator": "gemini_api",
|
||||
"distill": "claude_cli",
|
||||
"coder": "claude_cli",
|
||||
"research": "gemini_api",
|
||||
}
|
||||
|
||||
PRIORITY_KEYS = ["primary", "backup_1", "backup_2", "backup_3", "backup_4"]
|
||||
|
||||
|
||||
# ── Storage ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _registry_path(username: str) -> Path:
|
||||
return settings.home_root() / username / "model_registry.json"
|
||||
|
||||
|
||||
def _local_llm_path(username: str) -> Path:
|
||||
return settings.home_root() / username / "local_llm.json"
|
||||
|
||||
|
||||
def _empty() -> dict:
|
||||
return {"version": 1, "hosts": [], "models": [], "roles": {}}
|
||||
|
||||
|
||||
def _load(username: str) -> dict:
|
||||
path = _registry_path(username)
|
||||
if path.exists():
|
||||
try:
|
||||
data = json.loads(path.read_text())
|
||||
if isinstance(data, dict) and "version" in data:
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("model_registry.json for %s is unreadable — starting fresh", username)
|
||||
return _empty()
|
||||
|
||||
# No registry yet — try migrating from local_llm.json
|
||||
legacy = _local_llm_path(username)
|
||||
if legacy.exists():
|
||||
data = _migrate_from_local_llm(username, legacy)
|
||||
_save(username, data)
|
||||
logger.info("Migrated local_llm.json → model_registry.json for %s", username)
|
||||
return data
|
||||
|
||||
return _empty()
|
||||
|
||||
|
||||
def _save(username: str, data: dict) -> None:
|
||||
_registry_path(username).write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
# ── Migration ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def _migrate_from_local_llm(username: str, path: Path) -> dict:
|
||||
"""Convert local_llm.json (hosts/models/active_model_id) → model_registry format."""
|
||||
try:
|
||||
old = json.loads(path.read_text())
|
||||
except Exception:
|
||||
return _empty()
|
||||
|
||||
data = _empty()
|
||||
|
||||
# Handle v0 flat format
|
||||
if "hosts" not in old:
|
||||
api_url = old.get("api_url") or settings.local_api_url
|
||||
api_key = old.get("api_key") or settings.local_api_key
|
||||
model_name = old.get("model") or settings.local_model
|
||||
if not api_url:
|
||||
return data
|
||||
host_id = secrets.token_hex(4)
|
||||
old = {
|
||||
"hosts": [{"id": host_id, "label": "Local Model Server", "api_url": api_url, "api_key": api_key}],
|
||||
"models": [{"id": secrets.token_hex(4), "host_id": host_id, "label": model_name, "model_name": model_name}] if model_name else [],
|
||||
"active_model_id": None,
|
||||
}
|
||||
if old["models"]:
|
||||
old["active_model_id"] = old["models"][0]["id"]
|
||||
|
||||
data["hosts"] = old.get("hosts", [])
|
||||
|
||||
for m in old.get("models", []):
|
||||
data["models"].append({
|
||||
"id": m["id"],
|
||||
"type": "local_openai",
|
||||
"label": m.get("label") or m.get("model_name", ""),
|
||||
"model_name": m.get("model_name", ""),
|
||||
"host_id": m.get("host_id"),
|
||||
"context_k": 0,
|
||||
"tags": [],
|
||||
})
|
||||
|
||||
# Build initial role assignments
|
||||
active_id = old.get("active_model_id")
|
||||
distill_type = settings.distill_backend_mid or None
|
||||
|
||||
roles: dict[str, dict] = {}
|
||||
if active_id and any(m["id"] == active_id for m in data["models"]):
|
||||
roles["chat"] = {"primary": active_id}
|
||||
|
||||
if distill_type == "local" and active_id:
|
||||
roles["distill"] = {"primary": active_id}
|
||||
|
||||
data["roles"] = roles
|
||||
return data
|
||||
|
||||
|
||||
# ── Model resolution ──────────────────────────────────────────────────────────
|
||||
|
||||
def _resolve_model(registry: dict, model_id: str) -> dict | None:
|
||||
"""Resolve a model_id to its full config dict, or None if not found."""
|
||||
builtins = _builtins()
|
||||
|
||||
# Built-in IDs take priority over user-defined entries with the same ID
|
||||
if model_id in builtins:
|
||||
return dict(builtins[model_id])
|
||||
|
||||
model = next((m for m in registry.get("models", []) if m["id"] == model_id), None)
|
||||
if not model:
|
||||
return None
|
||||
|
||||
if model.get("type") == "local_openai":
|
||||
host_id = model.get("host_id")
|
||||
host = next((h for h in registry.get("hosts", []) if h["id"] == host_id), None)
|
||||
if not host:
|
||||
logger.warning("model %s references missing host_id %s", model_id, host_id)
|
||||
return None
|
||||
return {**model, "api_url": host.get("api_url", ""), "api_key": host.get("api_key", "")}
|
||||
|
||||
return dict(model)
|
||||
|
||||
|
||||
def get_model_for_role(username: str, role: str) -> dict | None:
|
||||
"""
|
||||
Return the resolved model config for the given role.
|
||||
|
||||
Resolution order:
|
||||
1. User registry: roles[role].primary → backup_1 → ... → backup_4
|
||||
2. .env: ROLE_<ROLE> = builtin model ID
|
||||
3. Hardcoded last-resort default per role
|
||||
4. claude_cli (absolute fallback)
|
||||
"""
|
||||
registry = _load(username)
|
||||
role_cfg = registry.get("roles", {}).get(role, {})
|
||||
|
||||
for key in PRIORITY_KEYS:
|
||||
model_id = role_cfg.get(key)
|
||||
if not model_id:
|
||||
continue
|
||||
resolved = _resolve_model(registry, model_id)
|
||||
if resolved:
|
||||
return resolved
|
||||
logger.debug("role %s.%s = %s but model not found", role, key, model_id)
|
||||
|
||||
# .env default
|
||||
env_type = settings.get_role_default(role)
|
||||
builtins = _builtins()
|
||||
if env_type and env_type in builtins:
|
||||
return dict(builtins[env_type])
|
||||
|
||||
# Hardcoded last resort
|
||||
fallback_id = _ROLE_LAST_RESORT.get(role, "claude_cli")
|
||||
return dict(builtins.get(fallback_id, builtins["claude_cli"]))
|
||||
|
||||
|
||||
def get_best_local_model(username: str, role: str = "chat") -> dict | None:
|
||||
"""
|
||||
Return the best available local_openai model for the given role.
|
||||
Used when the user explicitly selects "local" backend in the UI.
|
||||
Tries the role's priority chain first, then any configured local model.
|
||||
"""
|
||||
registry = _load(username)
|
||||
role_cfg = registry.get("roles", {}).get(role, {})
|
||||
|
||||
for key in PRIORITY_KEYS:
|
||||
model_id = role_cfg.get(key)
|
||||
if not model_id:
|
||||
continue
|
||||
resolved = _resolve_model(registry, model_id)
|
||||
if resolved and resolved.get("type") == "local_openai":
|
||||
return resolved
|
||||
|
||||
# Fall back to first configured local model
|
||||
for model in registry.get("models", []):
|
||||
if model.get("type") == "local_openai":
|
||||
resolved = _resolve_model(registry, model["id"])
|
||||
if resolved:
|
||||
return resolved
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ── Read API (for UI and callers) ─────────────────────────────────────────────
|
||||
|
||||
def get_registry(username: str) -> dict:
|
||||
"""Return the full registry (with built-in models injected for display)."""
|
||||
return _load(username)
|
||||
|
||||
|
||||
def get_all_models(username: str) -> list[dict]:
|
||||
"""Return all user-defined models (resolved — hosts merged in)."""
|
||||
registry = _load(username)
|
||||
out = []
|
||||
for m in registry.get("models", []):
|
||||
resolved = _resolve_model(registry, m["id"])
|
||||
if resolved:
|
||||
out.append(resolved)
|
||||
return out
|
||||
|
||||
|
||||
def get_defined_roles(username: str) -> dict[str, dict]:
|
||||
"""Return the roles section of the registry, filling gaps with empty dicts."""
|
||||
registry = _load(username)
|
||||
roles = registry.get("roles", {})
|
||||
result = {}
|
||||
for role in settings.get_defined_roles():
|
||||
result[role] = roles.get(role, {})
|
||||
return result
|
||||
|
||||
|
||||
# ── Write API (CRUD) ──────────────────────────────────────────────────────────
|
||||
|
||||
def save_host(username: str, host_id: str | None,
|
||||
label: str, api_url: str, api_key: str) -> str:
|
||||
"""Create or update a host. Returns the host ID."""
|
||||
data = _load(username)
|
||||
|
||||
if host_id:
|
||||
for h in data["hosts"]:
|
||||
if h["id"] == host_id:
|
||||
h["label"] = label.strip()
|
||||
h["api_url"] = api_url.strip()
|
||||
if api_key.strip():
|
||||
h["api_key"] = api_key.strip()
|
||||
_save(username, data)
|
||||
return host_id
|
||||
host_id = None # not found — create new
|
||||
|
||||
host_id = secrets.token_hex(4)
|
||||
data["hosts"].append({
|
||||
"id": host_id,
|
||||
"label": label.strip(),
|
||||
"api_url": api_url.strip(),
|
||||
"api_key": api_key.strip(),
|
||||
})
|
||||
_save(username, data)
|
||||
return host_id
|
||||
|
||||
|
||||
def remove_host(username: str, host_id: str) -> bool:
|
||||
"""Remove a host and all models that reference it. Returns True if found."""
|
||||
data = _load(username)
|
||||
before = len(data["hosts"])
|
||||
data["hosts"] = [h for h in data["hosts"] if h["id"] != host_id]
|
||||
data["models"] = [m for m in data["models"] if m.get("host_id") != host_id]
|
||||
# Clear any role assignments that pointed to removed models
|
||||
removed_ids = {m["id"] for m in data["models"] if m.get("host_id") == host_id}
|
||||
for role_cfg in data.get("roles", {}).values():
|
||||
for key in PRIORITY_KEYS:
|
||||
if role_cfg.get(key) in removed_ids:
|
||||
role_cfg[key] = None
|
||||
_save(username, data)
|
||||
return len(data["hosts"]) < before
|
||||
|
||||
|
||||
def save_model(username: str, model_id: str | None, host_id: str,
|
||||
label: str, model_name: str, context_k: int = 0,
|
||||
tags: list[str] | None = None) -> str:
|
||||
"""Create or update a model entry. Returns the model ID."""
|
||||
data = _load(username)
|
||||
tags = tags or []
|
||||
|
||||
if model_id:
|
||||
for m in data["models"]:
|
||||
if m["id"] == model_id:
|
||||
m["host_id"] = host_id
|
||||
m["label"] = label.strip() or model_name.strip()
|
||||
m["model_name"] = model_name.strip()
|
||||
m["context_k"] = context_k
|
||||
m["tags"] = tags
|
||||
_save(username, data)
|
||||
return model_id
|
||||
model_id = None
|
||||
|
||||
model_id = secrets.token_hex(4)
|
||||
data["models"].append({
|
||||
"id": model_id,
|
||||
"type": "local_openai",
|
||||
"label": label.strip() or model_name.strip(),
|
||||
"model_name": model_name.strip(),
|
||||
"host_id": host_id,
|
||||
"context_k": context_k,
|
||||
"tags": tags,
|
||||
})
|
||||
_save(username, data)
|
||||
return model_id
|
||||
|
||||
|
||||
def remove_model(username: str, model_id: str) -> bool:
|
||||
"""Remove a model and clear any role assignments pointing to it."""
|
||||
data = _load(username)
|
||||
before = len(data["models"])
|
||||
data["models"] = [m for m in data["models"] if m["id"] != model_id]
|
||||
|
||||
for role_cfg in data.get("roles", {}).values():
|
||||
for key in PRIORITY_KEYS:
|
||||
if role_cfg.get(key) == model_id:
|
||||
role_cfg[key] = None
|
||||
|
||||
_save(username, data)
|
||||
return len(data["models"]) < before
|
||||
|
||||
|
||||
def set_role(username: str, role: str, priority: str, model_id: str | None) -> bool:
|
||||
"""
|
||||
Assign a model to a role priority slot.
|
||||
|
||||
priority must be one of: primary, backup_1, backup_2, backup_3, backup_4
|
||||
model_id None clears the slot.
|
||||
model_id "claude_cli" / "gemini_cli" / "gemini_api" are valid built-in IDs.
|
||||
Returns False if model_id is set but not found.
|
||||
"""
|
||||
if priority not in PRIORITY_KEYS:
|
||||
return False
|
||||
|
||||
data = _load(username)
|
||||
|
||||
if model_id and model_id not in _builtins():
|
||||
if not any(m["id"] == model_id for m in data["models"]):
|
||||
return False
|
||||
|
||||
roles = data.setdefault("roles", {})
|
||||
if role not in roles:
|
||||
roles[role] = {}
|
||||
roles[role][priority] = model_id or None
|
||||
|
||||
_save(username, data)
|
||||
return True
|
||||
|
||||
|
||||
def fetch_models_from_host(api_url: str, api_key: str) -> list[str]:
|
||||
"""Synchronously fetch the model list from an OpenAI-compatible host."""
|
||||
import httpx
|
||||
url = api_url.rstrip("/") + "/api/models"
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
resp = httpx.get(url, headers=headers, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
models = data.get("data", [])
|
||||
return sorted(m.get("id", m.get("name", "")) for m in models if m.get("id") or m.get("name"))
|
||||
@@ -72,21 +72,33 @@ def _gemini_status() -> dict:
|
||||
return {"ok": False, "error": str(e), "warning": True, "authenticated": False}
|
||||
|
||||
|
||||
async def _local_status() -> dict:
|
||||
if not settings.local_api_url:
|
||||
async def _local_status(username: str = "scott") -> dict:
|
||||
"""Check reachability of the user's configured local model host."""
|
||||
import model_registry
|
||||
cfg = model_registry.get_best_local_model(username)
|
||||
if not cfg:
|
||||
return {"configured": False}
|
||||
api_url = cfg.get("api_url", "")
|
||||
if not api_url:
|
||||
return {"configured": False}
|
||||
try:
|
||||
import httpx
|
||||
url = settings.local_api_url.rstrip("/") + "/api/models"
|
||||
url = api_url.rstrip("/") + "/api/models"
|
||||
headers = {}
|
||||
if settings.local_api_key:
|
||||
headers["Authorization"] = f"Bearer {settings.local_api_key}"
|
||||
api_key = cfg.get("api_key", "")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
reachable = resp.status_code < 400
|
||||
return {"configured": True, "reachable": reachable, "model": settings.local_model}
|
||||
return {
|
||||
"configured": True,
|
||||
"reachable": reachable,
|
||||
"model": cfg.get("model_name", ""),
|
||||
"label": cfg.get("label", ""),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"configured": True, "reachable": False, "error": str(e), "model": settings.local_model}
|
||||
return {"configured": True, "reachable": False, "error": str(e), "model": cfg.get("model_name", "")}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
|
||||
@@ -11,7 +11,7 @@ from session_store import load as load_session, save as save_session, list_all,
|
||||
from config import settings
|
||||
from persona import set_context, validate as validate_persona
|
||||
from auth_utils import COOKIE_NAME, decode_token
|
||||
import user_settings
|
||||
import model_registry
|
||||
import event_bus
|
||||
|
||||
|
||||
@@ -138,15 +138,15 @@ _BACKEND_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}
|
||||
|
||||
|
||||
def _local_model_info(request: Request) -> dict | None:
|
||||
"""Return active local model {label, model_name} for the session user, or None."""
|
||||
"""Return the best local model {label, model_name} for the session user, or None."""
|
||||
try:
|
||||
token = request.cookies.get(COOKIE_NAME)
|
||||
username = decode_token(token) if token else None
|
||||
if not username:
|
||||
return None
|
||||
cfg = user_settings.get_active_local_model(username)
|
||||
cfg = model_registry.get_best_local_model(username, "chat")
|
||||
if cfg:
|
||||
return {"label": cfg["label"], "model_name": cfg["model_name"]}
|
||||
return {"label": cfg.get("label", ""), "model_name": cfg.get("model_name", "")}
|
||||
except (jwt.InvalidTokenError, Exception):
|
||||
pass
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user