import asyncio
import logging
import os
import signal
import subprocess

from config import settings
import event_bus

logger = logging.getLogger(__name__)

# Track active Gemini process group IDs so we can kill them on shutdown
_active_pgroups: set[int] = set()


def _register_pgroup(pid: int) -> None:
    """Remember a Gemini process group so cleanup() can kill it at shutdown."""
    _active_pgroups.add(pid)


def _unregister_pgroup(pid: int) -> None:
    """Forget a Gemini process group (no-op if it was never registered)."""
    _active_pgroups.discard(pid)


def _kill_pgroup(pgid: int) -> None:
    """SIGKILL every process in group *pgid*; a group that is already gone is ignored."""
    try:
        os.killpg(pgid, signal.SIGKILL)
    except ProcessLookupError:
        pass


async def cleanup() -> None:
    """Kill any lingering Gemini process groups. Call from lifespan shutdown.

    Best-effort: groups that already exited are skipped silently, and any other
    OS error (e.g. EPERM) is logged instead of aborting, so every remaining
    group still gets a SIGKILL and the registry is always cleared.
    """
    for pid in list(_active_pgroups):
        try:
            os.killpg(pid, signal.SIGKILL)
        except ProcessLookupError:
            continue
        except OSError as e:
            logger.warning("Shutdown: could not kill Gemini process group %d: %s", pid, e)
            continue
        logger.info("Shutdown: killed Gemini process group %d", pid)
    _active_pgroups.clear()


# Map from registry model type → dispatch function key
_TYPE_TO_BACKEND = {
    "claude_cli": "claude",
    "gemini_cli": "gemini",
    "gemini_api": "gemini",  # gemini_api falls back to CLI in this context
    "local_openai": "local",
}

# Explicit UI toggle values (kept for backward compat)
_EXPLICIT_BACKENDS = ("claude", "gemini", "local")

_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude"}


async def complete(
    system_prompt: str,
    messages: list[dict],
    model: str | None = None,
    role: str = "chat",
    max_tokens: int = 2048,
) -> tuple[str, str]:
    """
    Returns (response_text, actual_backend_used).

    model: explicit backend override ("claude" | "gemini" | "local") from UI toggle.
           Any other value (including None) = resolve via model registry for the given role.
    role: registry role used for registry lookups (default: "chat").
    max_tokens: accepted for API compatibility; not currently forwarded to any backend.

    If the primary backend raises, the request is retried once on the backend
    given by _FALLBACK (with no registry config); an error from the fallback
    propagates to the caller. A Claude failure whose message looks like an
    auth problem also publishes a "claude_auth_expired" event.
    """
    # Local imports, presumably to avoid import cycles at module load — confirm.
    import model_registry as _reg
    from persona import _user

    username = _user.get()
    resolved_cfg: dict | None = None

    if model in _EXPLICIT_BACKENDS:
        # User explicitly selected a backend in the UI
        if model == "local":
            resolved_cfg = _reg.get_best_local_model(username, role)
            if not resolved_cfg:
                raise RuntimeError("No local model configured — add one at /settings/models")
        primary = model
    else:
        # Role-based routing via model registry
        resolved = _reg.get_model_for_role(username, role)
        if resolved:
            resolved_cfg = resolved
            primary = _TYPE_TO_BACKEND.get(resolved["type"], "claude")
        else:
            primary = settings.primary_backend

    fallback = _FALLBACK.get(primary, "claude")
    try:
        response = await _dispatch(primary, system_prompt, messages, resolved_cfg)
        return response, primary
    except Exception as e:
        err_str = str(e)
        logger.warning("%s failed (%s) — falling back to %s", primary, e, fallback)
        if primary == "claude" and any(k in err_str for k in ("401", "authenticate", "expired", "OAuth")):
            await event_bus.publish({"type": "claude_auth_expired"})
        # Retrying inside the handler keeps the primary error chained as __context__
        # if the fallback fails too.
        response = await _dispatch(fallback, system_prompt, messages, None)
        return response, fallback


async def _dispatch(
    backend: str,
    system_prompt: str,
    messages: list[dict],
    model_cfg: dict | None,
) -> str:
    """Route a request to the backend implementation.

    "gemini" and "local" go to their own handlers; every other value is sent
    to Claude. model_cfg is ignored by the Gemini path.
    """
    if backend == "gemini":
        return await _gemini(system_prompt, messages)
    if backend == "local":
        return await _local(system_prompt, messages, model_cfg)
    return await _claude(system_prompt, messages, model_cfg)


def _fresh_claude_token() -> str | None:
    """Read the current OAuth access token from the Claude credentials file.

    The token in the systemd .env goes stale (it rotates on each login).
    Reading directly from ~/.claude/.credentials.json always gets the latest.
    Returns None (and logs at debug level) if the file is missing or malformed.
    """
    import json as _json

    creds_path = os.path.expanduser("~/.claude/.credentials.json")
    try:
        with open(creds_path, encoding="utf-8") as f:
            data = _json.load(f)
        return data["claudeAiOauth"]["accessToken"]
    except Exception as e:
        # Deliberately best-effort: without a fresh token we fall back to the env var.
        logger.debug("Could not read Claude credentials file: %s", e)
        return None


async def _claude(system_prompt: str, messages: list[dict], model_cfg: dict | None) -> str:
    """Run one non-interactive Claude CLI completion and return its stdout.

    model_cfg["model_name"], when it is a real model name, is passed via --model.
    Raises whatever _run raises (RuntimeError on non-zero exit, etc.).
    """
    model_name = model_cfg.get("model_name") if model_cfg else None
    cmd = [
        "claude",
        "--print",
        "--no-session-persistence",
        "--output-format",
        "text",
    ]
    # Only pass --model if it's a real model name (not a backend type string)
    if model_name and model_name not in ("claude", "gemini", "local", ""):
        cmd.extend(["--model", model_name])
    if system_prompt:
        cmd.extend(["--system-prompt", system_prompt])
    # NOTE(review): the conversation travels as a single argv string; Linux caps one
    # argument at MAX_ARG_STRLEN (typically 128 KiB), so very long histories may
    # fail to exec — consider piping the prompt via stdin if that becomes an issue.
    cmd.append(_build_conversation(messages))

    # Always use the freshest token from the credentials file so the systemd
    # service doesn't break when the env-var token rotates after a login.
    env = os.environ.copy()
    token = _fresh_claude_token()
    if token:
        env["CLAUDE_CODE_OAUTH_TOKEN"] = token
    env.pop("ANTHROPIC_API_KEY", None)  # never let a stale API key override OAuth

    return await _run(cmd, timeout=settings.timeout_claude, env=env)


async def _local(system_prompt: str, messages: list[dict], model_cfg: dict | None = None) -> str:
    """OpenAI-compatible backend — Open WebUI / Ollama.

    model_cfg is pre-resolved by complete() via model_registry.
    Falls back to registry lookup if not provided.

    Reads "api_url", "api_key", "model_name" and optional "host_type" from the
    config. Missing/empty url or model raise a descriptive RuntimeError (a
    missing api_key just means no Authorization header). HTTP errors propagate
    from httpx's raise_for_status().
    """
    import httpx

    cfg = model_cfg
    if not cfg:
        # Fallback: resolve directly from registry
        import model_registry as _reg
        from persona import _user

        cfg = _reg.get_best_local_model(_user.get())
        if not cfg:
            raise RuntimeError("No local model configured — add one at /settings/models")

    # .get() so incomplete configs reach the friendly errors below instead of KeyError
    api_url = cfg.get("api_url")
    api_key = cfg.get("api_key")
    model = cfg.get("model_name")
    if not api_url:
        raise RuntimeError("local_api_url not configured — set LOCAL_API_URL in .env or add a host at /settings/local")
    if not model:
        raise RuntimeError("local_model not configured — add a model at /settings/local")

    host_type = cfg.get("host_type", "openwebui")
    # "openwebui" uses Open WebUI/Ollama path layout; "openai" uses standard OpenAI layout
    chat_path = "/chat/completions" if host_type == "openai" else "/api/chat/completions"
    logger.info("local backend (%s): %s @ %s", host_type, model, api_url)

    msgs: list[dict] = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.extend(messages)

    url = api_url.rstrip("/") + chat_path
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {"model": model, "messages": msgs}

    async with httpx.AsyncClient(timeout=settings.timeout_local) as client:
        resp = await client.post(url, json=payload, headers=headers)
        resp.raise_for_status()
        data = resp.json()

    try:
        text = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise RuntimeError(f"Local model returned an unexpected response: {data!r:.200}") from e
    if not text or not text.strip():
        raise RuntimeError("Local model returned an empty response")
    return text.strip()


async def _gemini(system_prompt: str, messages: list[dict]) -> str:
    """Run one Gemini CLI completion and return its cleaned stdout.

    Raises RuntimeError if the CLI is missing, times out, exits non-zero, or
    produces no usable output. On timeout or cancellation the whole process
    group is SIGKILLed.
    """
    # Gemini CLI spawns MCP child processes that keep stdout pipes open after responding.
    # start_new_session=True puts the whole tree in its own process group so
    # os.killpg kills everything at once on timeout.
    cmd = [
        "gemini",
        "--output-format",
        "text",
        "--extensions",
        "",  # disable all extensions — prevents MCP child processes
        "-p",
        _build_prompt(system_prompt, messages),
    ]
    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            start_new_session=True,
        )
    except FileNotFoundError:
        raise RuntimeError("gemini not found in PATH") from None

    # With start_new_session=True the child is its own group leader, so pgid == pid.
    _register_pgroup(proc.pid)
    timeout = settings.timeout_gemini
    try:
        stdout_bytes, stderr_bytes = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        _kill_pgroup(proc.pid)
        raise RuntimeError(f"Gemini timed out after {timeout}s")
    except asyncio.CancelledError:
        _kill_pgroup(proc.pid)
        raise
    finally:
        _unregister_pgroup(proc.pid)

    raw = stdout_bytes.decode(errors="replace")
    if proc.returncode != 0:
        # Same convention as _run: prefer stderr, then stdout, then the bare exit code,
        # so a CLI error is never passed off as the model's reply.
        detail = (
            stderr_bytes.decode(errors="replace").strip()
            or raw.strip()
            or f"exit code {proc.returncode}"
        )
        raise RuntimeError(f"gemini failed: {detail}")

    clean = _clean_gemini_output(raw)
    if not clean:
        raise RuntimeError("Gemini returned an empty response")
    return clean


# Lines Gemini CLI writes to stdout that are not part of the actual response
_GEMINI_NOISE = (
    "Loaded cached credentials",
    "Loading extension:",
    "Server '",
    "Listening for",
    "Model is overloaded",
    "High demand",
    "Retrying",
    "retrying",
    "429",
    "quota",
)


def _clean_gemini_output(text: str) -> str:
    """Drop lines whose stripped text starts with a _GEMINI_NOISE prefix; strip the rest."""
    lines = [line for line in text.splitlines() if not line.strip().startswith(_GEMINI_NOISE)]
    return "\n".join(lines).strip()


async def _run(cmd: list[str], timeout: int = 60, env: dict | None = None) -> str:
    """Run *cmd* in a worker thread and return its stripped stdout.

    Output is decoded as UTF-8 (undecodable bytes replaced) regardless of the
    service locale. Raises RuntimeError on a non-zero exit, using stderr, else
    stdout, else the exit code as detail. subprocess.TimeoutExpired propagates
    (after the child is killed) when *timeout* seconds elapse, and
    FileNotFoundError when the executable is missing.
    """
    result = await asyncio.to_thread(
        subprocess.run,
        cmd,
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
        timeout=timeout,
        env=env,
    )
    if result.returncode != 0:
        detail = result.stderr.strip() or result.stdout.strip() or f"exit code {result.returncode}"
        raise RuntimeError(f"{cmd[0]} failed: {detail}")
    return result.stdout.strip()


def _build_conversation(messages: list[dict]) -> str:
    """Conversation only — used for Claude (system prompt passed separately).

    Every message except the last is rendered as "<label>: <content>", labelled
    with settings.user_name for role "user" and settings.agent_name otherwise;
    the last message's content is appended verbatim. Empty input yields "".
    """
    parts = []
    prior = messages[:-1]
    if prior:
        history_lines = []
        for msg in prior:
            label = settings.user_name if msg["role"] == "user" else settings.agent_name
            history_lines.append(f"{label}: {msg['content']}")
        parts.append("\n" + "\n\n".join(history_lines) + "\n")
    parts.append(messages[-1]["content"] if messages else "")
    return "\n\n".join(parts)


def _build_prompt(system_prompt: str, messages: list[dict]) -> str:
    """Full prompt with system context embedded — used for Gemini.

    The system prompt (if any) comes first, followed by _build_conversation(messages).
    """
    parts = []
    if system_prompt:
        parts.append(f"\n{system_prompt}\n")
    parts.append(_build_conversation(messages))
    return "\n\n".join(parts)