Switches the orchestrator's final response from a fire-and-wait model to a
live SSE stream so text appears token-by-token as the model generates it.
- llm_client: complete() gains token_sink param; anthropic_api backend uses
client.messages.stream(); local backend uses httpx SSE streaming; non-streaming
backends (claude_cli, gemini_cli) emit the full text as one chunk
- orchestrator_engine + openai_orchestrator: token_sink threaded through run(),
_run_from_contents(), _claude_handoff(), and _run_from_messages()
- routers/orchestrator: each job gets an asyncio.Queue; _on_progress and
_token_sink write progress/token events to it; _finalize_job emits done,
error handler emits error, confirmation gate emits confirm; new GET
/orchestrate/{job_id}/stream SSE endpoint with 20s keepalive
- app.js: _doOrchestrate switches from 2s poll loop to EventSource; thinking
bubble converts to a streaming message on first token; auto-scroll while
streaming; confirm/error/done events handled; finalization unchanged
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
513 lines
18 KiB
Python
513 lines
18 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import signal
|
|
import subprocess
|
|
from config import settings
|
|
import event_bus
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Track active Gemini process group IDs so we can kill them on shutdown
|
|
_active_pgroups: set[int] = set()
|
|
|
|
|
|
def _register_pgroup(pid: int) -> None:
|
|
_active_pgroups.add(pid)
|
|
|
|
|
|
def _unregister_pgroup(pid: int) -> None:
|
|
_active_pgroups.discard(pid)
|
|
|
|
|
|
async def cleanup() -> None:
|
|
"""Kill any lingering Gemini process groups. Call from lifespan shutdown."""
|
|
for pid in list(_active_pgroups):
|
|
try:
|
|
os.killpg(pid, signal.SIGKILL)
|
|
logger.info("Shutdown: killed Gemini process group %d", pid)
|
|
except ProcessLookupError:
|
|
pass
|
|
_active_pgroups.clear()
|
|
|
|
|
|
# Map from registry model type → dispatch function key
|
|
_TYPE_TO_BACKEND = {
|
|
"claude_cli": "claude",
|
|
"gemini_cli": "gemini",
|
|
"gemini_api": "gemini", # gemini_api falls back to CLI in this context
|
|
"local_openai": "local",
|
|
"anthropic_api": "anthropic_api",
|
|
}
|
|
|
|
# Explicit UI toggle values (kept for backward compat)
|
|
_EXPLICIT_BACKENDS = ("claude", "gemini", "local")
|
|
_FALLBACK = {"claude": "gemini", "gemini": "claude", "local": "claude", "anthropic_api": "claude"}
|
|
|
|
|
|
async def complete(
    system_prompt: str,
    messages: list[dict],
    model: str | None = None,
    role: str = "chat",
    slot: str | None = None,
    max_tokens: int = 2048,
    attachment: dict | None = None,
    token_sink=None,  # async (str) -> None; if set, stream tokens as they arrive
) -> tuple[str, str]:
    """Generate a reply and report which backend produced it.

    Returns ``(response_text, actual_backend_used)``.

    Routing precedence:
      1. ``slot`` (Phase 3) — a specific role slot ("primary" | "backup_1" |
         "backup_2"). Resolves exactly that slot with no fallback chain and
         takes priority over ``model``. An unconfigured slot is ignored.
      2. ``model`` — legacy backend override from the old UI toggle
         ("claude" | "gemini" | "local"). "local" requires a configured local
         model and raises RuntimeError otherwise.
      3. Otherwise the model registry picks a model for ``role`` (default
         "chat"), falling back to ``settings.primary_backend``.

    If the chosen backend raises and a model config was resolved, the error
    propagates unchanged. Otherwise a single retry is made on the
    ``_FALLBACK`` backend. A Claude failure whose message looks auth-related
    also publishes a ``claude_auth_expired`` event on the event bus.

    NOTE(review): ``max_tokens`` is accepted but never forwarded to a backend
    from here — confirm whether that is intended.
    """
    import model_registry as _reg
    from persona import _user

    username = _user.get()
    resolved_cfg: dict | None = None
    primary: str | None = None

    # Phase 3: a pinned slot wins outright (no fallback within the role).
    if slot is not None:
        resolved_cfg = _reg.get_model_for_slot(username, role, slot)
        if resolved_cfg:
            primary = _TYPE_TO_BACKEND.get(resolved_cfg["type"], "claude")

    # No usable slot: legacy toggle or registry-based auto routing.
    if primary is None:
        if model not in _EXPLICIT_BACKENDS:
            resolved = _reg.get_model_for_role(username, role)
            if resolved:
                resolved_cfg = resolved
                primary = _TYPE_TO_BACKEND.get(resolved["type"], "claude")
            else:
                primary = settings.primary_backend
        else:
            if model == "local":
                resolved_cfg = _reg.get_best_local_model(username, role)
                if not resolved_cfg:
                    raise RuntimeError("No local model configured — add one at /settings/models")
            primary = model

    fallback = _FALLBACK.get(primary, "claude")

    try:
        text = await _dispatch(
            primary, system_prompt, messages, resolved_cfg,
            attachment=attachment, token_sink=token_sink,
        )
    except Exception as exc:
        message = str(exc)
        auth_markers = ("401", "authenticate", "expired", "OAuth")
        if primary == "claude" and any(marker in message for marker in auth_markers):
            await event_bus.publish({"type": "claude_auth_expired"})
        # An explicitly configured model / pinned slot surfaces its error as-is.
        if resolved_cfg is not None:
            logger.error("%s failed (no fallback — model explicitly configured): %s", primary, exc)
            raise
        logger.warning("%s failed (%s) — falling back to %s", primary, exc, fallback)
        text = await _dispatch(fallback, system_prompt, messages, None, token_sink=token_sink)
        return text, fallback
    else:
        return text, primary
|
|
|
|
|
|
async def _dispatch(
    backend: str,
    system_prompt: str,
    messages: list[dict],
    model_cfg: dict | None,
    attachment: dict | None = None,
    token_sink=None,
) -> str:
    """Route one completion request to a backend implementation.

    backend: "gemini" | "local" | "anthropic_api"; anything else runs the
    Claude CLI.

    When ``token_sink`` is given, the anthropic_api and local backends stream
    deltas into it. Every other backend, and local requests that carry an
    ``attachment``, are emitted to the sink as one chunk once complete.

    Returns the full response text.
    """
    if token_sink and backend == "anthropic_api":
        return await _anthropic_api_streaming(token_sink, system_prompt, messages, model_cfg)
    if token_sink and backend == "local" and not attachment:
        # _local_streaming() takes no attachment, so image requests fall
        # through to _local() below instead of silently losing the image.
        return await _local_streaming(token_sink, system_prompt, messages, model_cfg)

    if backend == "gemini":
        text = await _gemini(system_prompt, messages)
    elif backend == "local":
        text = await _local(system_prompt, messages, model_cfg, attachment=attachment)
    elif backend == "anthropic_api":
        text = await _anthropic_api(system_prompt, messages, model_cfg)
    else:
        text = await _claude(system_prompt, messages, model_cfg)

    # Non-streamed result: hand the whole text to the sink as a single chunk.
    if token_sink and text:
        await token_sink(text)

    return text
|
|
|
|
|
|
def _fresh_claude_token() -> str | None:
    """Return the OAuth access token stored in ~/.claude/.credentials.json.

    The copy in the systemd .env goes stale because the token rotates on each
    login, so the credentials file is re-read on every call. Any failure
    (missing file, bad JSON, missing keys) is logged at debug level and
    yields None.
    """
    import json as _json

    path = os.path.expanduser("~/.claude/.credentials.json")
    try:
        with open(path) as fh:
            creds = _json.load(fh)
        token = creds["claudeAiOauth"]["accessToken"]
    except Exception as exc:
        logger.debug("Could not read Claude credentials file: %s", exc)
        return None
    return token
|
|
|
|
|
|
async def _claude(system_prompt: str, messages: list[dict], model_cfg: dict | None) -> str:
    """Run one completion through the ``claude`` CLI in ``--print`` mode.

    The system prompt is passed via ``--system-prompt``. The flattened
    conversation from _build_conversation() is the final positional argument.
    ``--model`` is added only for a concrete model id. The child environment
    uses the freshest OAuth token from the credentials file. Returns the CLI's
    stripped stdout (see _run).
    """
    wanted = model_cfg.get("model_name") if model_cfg else None

    argv = [
        "claude", "--print",
        "--no-session-persistence",
        "--output-format", "text",
    ]
    # Backend type strings ("claude", "gemini", "local") are not model ids.
    if wanted and wanted not in ("claude", "gemini", "local", ""):
        argv += ["--model", wanted]
    if system_prompt:
        argv += ["--system-prompt", system_prompt]
    # NOTE(review): the prompt travels as a single argv entry, so a very long
    # conversation may hit the OS per-argument size limit, and a prompt that
    # begins with "-" may be parsed as an option — consider stdin instead.
    argv.append(_build_conversation(messages))

    # Use the freshest token from the credentials file so the systemd service
    # keeps working after the env-var token rotates on login.
    child_env = dict(os.environ)
    oauth = _fresh_claude_token()
    if oauth:
        child_env["CLAUDE_CODE_OAUTH_TOKEN"] = oauth
    # Never let a stale API key override OAuth.
    # NOTE(review): SOURCE lost its indentation; this pop is treated as
    # unconditional — confirm it was not meant to sit inside `if oauth:`.
    child_env.pop("ANTHROPIC_API_KEY", None)

    return await _run(argv, timeout=settings.timeout_claude, env=child_env)
|
|
|
|
|
|
async def _local(
|
|
system_prompt: str,
|
|
messages: list[dict],
|
|
model_cfg: dict | None = None,
|
|
attachment: dict | None = None,
|
|
) -> str:
|
|
"""OpenAI-compatible backend — Open WebUI / Ollama.
|
|
|
|
model_cfg is pre-resolved by complete() via model_registry.
|
|
Falls back to registry lookup if not provided.
|
|
attachment: optional image dict {filename, mime_type, data} for vision calls.
|
|
"""
|
|
import httpx
|
|
|
|
cfg = model_cfg
|
|
if not cfg:
|
|
# Fallback: resolve directly from registry
|
|
import model_registry as _reg
|
|
from persona import _user
|
|
cfg = _reg.get_best_local_model(_user.get())
|
|
if not cfg:
|
|
raise RuntimeError("No local model configured — add one at /settings/models")
|
|
|
|
api_url = cfg["api_url"]
|
|
api_key = cfg["api_key"]
|
|
model = cfg["model_name"]
|
|
|
|
if not api_url:
|
|
raise RuntimeError("local_api_url not configured — set LOCAL_API_URL in .env or add a host at /settings/models")
|
|
if not model:
|
|
raise RuntimeError("local_model not configured — add a model at /settings/models")
|
|
|
|
host_type = cfg.get("host_type", "openwebui")
|
|
# "openwebui" uses Open WebUI/Ollama path layout; "openai" uses standard OpenAI layout
|
|
chat_path = "/chat/completions" if host_type == "openai" else "/api/chat/completions"
|
|
logger.info("local backend (%s): %s @ %s", host_type, model, api_url)
|
|
|
|
msgs: list[dict] = []
|
|
if system_prompt:
|
|
msgs.append({"role": "system", "content": system_prompt})
|
|
|
|
# Build message list; inject image into the last user message when present.
|
|
for i, m in enumerate(messages):
|
|
is_last = (i == len(messages) - 1)
|
|
if is_last and m["role"] == "user" and attachment:
|
|
content: list[dict] = [{"type": "text", "text": m["content"]}]
|
|
content.append({
|
|
"type": "image_url",
|
|
"image_url": {"url": attachment["data"]},
|
|
})
|
|
msgs.append({"role": "user", "content": content})
|
|
else:
|
|
# Strip non-standard metadata fields before sending to the API
|
|
msgs.append({"role": m["role"], "content": m["content"]})
|
|
|
|
url = api_url.rstrip("/") + chat_path
|
|
headers: dict[str, str] = {}
|
|
if api_key:
|
|
headers["Authorization"] = f"Bearer {api_key}"
|
|
|
|
payload = {"model": model, "messages": msgs}
|
|
|
|
async with httpx.AsyncClient(timeout=settings.timeout_local) as client:
|
|
resp = await client.post(url, json=payload, headers=headers)
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
|
|
text = data["choices"][0]["message"]["content"]
|
|
if not text or not text.strip():
|
|
raise RuntimeError("Local model returned an empty response")
|
|
|
|
usage = data.get("usage") or {}
|
|
if usage.get("prompt_tokens") is not None:
|
|
import usage_tracker
|
|
from persona import _user
|
|
asyncio.create_task(usage_tracker.record(
|
|
username=_user.get(),
|
|
backend="local",
|
|
model_name=model,
|
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
|
completion_tokens=usage.get("completion_tokens", 0),
|
|
))
|
|
|
|
return text.strip()
|
|
|
|
|
|
async def _anthropic_api(system_prompt: str, messages: list[dict], model_cfg: dict | None) -> str:
|
|
"""Direct Anthropic API backend using the anthropic SDK."""
|
|
try:
|
|
import anthropic
|
|
except ImportError:
|
|
raise RuntimeError("anthropic SDK not installed — run: pip install 'anthropic>=0.40.0'")
|
|
|
|
cfg = model_cfg or {}
|
|
api_key = cfg.get("api_key", "")
|
|
model_name = cfg.get("model_name") or settings.default_model
|
|
|
|
if not api_key:
|
|
raise RuntimeError("No Anthropic API key — add one at /settings/models")
|
|
|
|
client = anthropic.AsyncAnthropic(api_key=api_key)
|
|
|
|
msgs = [{"role": m["role"], "content": m["content"]} for m in messages]
|
|
kwargs: dict = {
|
|
"model": model_name,
|
|
"max_tokens": 4096,
|
|
"messages": msgs,
|
|
}
|
|
if system_prompt:
|
|
kwargs["system"] = system_prompt
|
|
|
|
resp = await client.messages.create(**kwargs)
|
|
|
|
text = resp.content[0].text if resp.content else ""
|
|
if not text.strip():
|
|
raise RuntimeError("Anthropic API returned an empty response")
|
|
|
|
if resp.usage:
|
|
import usage_tracker
|
|
from persona import _user
|
|
asyncio.create_task(usage_tracker.record(
|
|
username=_user.get(),
|
|
backend="anthropic_api",
|
|
model_name=model_name,
|
|
prompt_tokens=resp.usage.input_tokens,
|
|
completion_tokens=resp.usage.output_tokens,
|
|
))
|
|
|
|
return text.strip()
|
|
|
|
|
|
async def _anthropic_api_streaming(
|
|
token_sink, system_prompt: str, messages: list[dict], model_cfg: dict | None
|
|
) -> str:
|
|
try:
|
|
import anthropic
|
|
except ImportError:
|
|
raise RuntimeError("anthropic SDK not installed — run: pip install 'anthropic>=0.40.0'")
|
|
|
|
cfg = model_cfg or {}
|
|
api_key = cfg.get("api_key", "")
|
|
model_name = cfg.get("model_name") or settings.default_model
|
|
|
|
if not api_key:
|
|
raise RuntimeError("No Anthropic API key — add one at /settings/models")
|
|
|
|
client = anthropic.AsyncAnthropic(api_key=api_key)
|
|
msgs = [{"role": m["role"], "content": m["content"]} for m in messages]
|
|
kwargs: dict = {"model": model_name, "max_tokens": 4096, "messages": msgs}
|
|
if system_prompt:
|
|
kwargs["system"] = system_prompt
|
|
|
|
full_text = ""
|
|
async with client.messages.stream(**kwargs) as stream:
|
|
async for chunk in stream.text_stream:
|
|
await token_sink(chunk)
|
|
full_text += chunk
|
|
|
|
final_msg = await stream.get_final_message()
|
|
if final_msg.usage:
|
|
import usage_tracker
|
|
from persona import _user
|
|
asyncio.create_task(usage_tracker.record(
|
|
username=_user.get(),
|
|
backend="anthropic_api",
|
|
model_name=model_name,
|
|
prompt_tokens=final_msg.usage.input_tokens,
|
|
completion_tokens=final_msg.usage.output_tokens,
|
|
))
|
|
|
|
return full_text.strip()
|
|
|
|
|
|
async def _local_streaming(
|
|
token_sink, system_prompt: str, messages: list[dict], model_cfg: dict | None
|
|
) -> str:
|
|
import httpx
|
|
import json as _json
|
|
|
|
cfg = model_cfg or {}
|
|
api_url = cfg.get("api_url", "")
|
|
api_key = cfg.get("api_key", "")
|
|
model = cfg.get("model_name", "")
|
|
host_type = cfg.get("host_type", "openwebui")
|
|
|
|
if not api_url:
|
|
raise RuntimeError("local_api_url not configured")
|
|
if not model:
|
|
raise RuntimeError("local_model not configured")
|
|
|
|
chat_path = "/chat/completions" if host_type == "openai" else "/api/chat/completions"
|
|
url = api_url.rstrip("/") + chat_path
|
|
headers: dict[str, str] = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
|
|
|
msgs: list[dict] = []
|
|
if system_prompt:
|
|
msgs.append({"role": "system", "content": system_prompt})
|
|
for m in messages:
|
|
msgs.append({"role": m["role"], "content": m["content"]})
|
|
|
|
payload = {"model": model, "messages": msgs, "stream": True}
|
|
full_text = ""
|
|
|
|
async with httpx.AsyncClient(timeout=settings.timeout_local) as client:
|
|
async with client.stream("POST", url, json=payload, headers=headers) as resp:
|
|
resp.raise_for_status()
|
|
async for line in resp.aiter_lines():
|
|
if not line or not line.startswith("data: "):
|
|
continue
|
|
data_str = line[6:].strip()
|
|
if data_str == "[DONE]":
|
|
break
|
|
try:
|
|
chunk = _json.loads(data_str)
|
|
delta = (chunk["choices"][0]["delta"].get("content") or "")
|
|
if delta:
|
|
await token_sink(delta)
|
|
full_text += delta
|
|
except Exception:
|
|
pass
|
|
|
|
return full_text.strip()
|
|
|
|
|
|
async def _gemini(system_prompt: str, messages: list[dict]) -> str:
    """Run one completion through the ``gemini`` CLI and return its cleaned stdout.

    The whole prompt (system block + conversation, see _build_prompt) is
    passed with ``-p``. The child runs in its own process group, which is
    registered for shutdown cleanup and SIGKILLed on timeout or cancellation.

    Raises:
        RuntimeError: the CLI is not on PATH, it exceeds
            ``settings.timeout_gemini``, or it yields no usable output (the
            message then includes the exit code and the tail of stderr).
        asyncio.CancelledError: re-raised after killing the process group.
    """
    # Gemini CLI spawns MCP child processes that keep stdout pipes open after responding.
    # start_new_session=True puts the whole tree in its own process group so
    # os.killpg kills everything at once on timeout.
    cmd = [
        "gemini",
        "--output-format", "text",
        "--extensions", "",  # disable all extensions — prevents MCP child processes
        "-p", _build_prompt(system_prompt, messages),
    ]

    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            start_new_session=True,
        )
    except FileNotFoundError:
        raise RuntimeError("gemini not found in PATH") from None

    _register_pgroup(proc.pid)
    timeout = settings.timeout_gemini
    try:
        stdout_bytes, stderr_bytes = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        raise RuntimeError(f"Gemini timed out after {timeout}s")
    except asyncio.CancelledError:
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        raise
    finally:
        _unregister_pgroup(proc.pid)

    # errors="replace": one invalid byte must not discard an otherwise usable reply.
    clean = _clean_gemini_output(stdout_bytes.decode(errors="replace"))
    if not clean:
        message = "Gemini returned an empty response"
        if proc.returncode:
            message += f" (exit code {proc.returncode})"
        detail = (stderr_bytes or b"").decode(errors="replace").strip()
        if detail:
            # Keep only the tail so a chatty stderr cannot flood the logs.
            message += f": {detail[-500:]}"
        raise RuntimeError(message)
    return clean
|
|
|
|
|
|
# Lines Gemini CLI writes to stdout that are not part of the actual response
|
|
_GEMINI_NOISE = (
|
|
"Loaded cached credentials",
|
|
"Loading extension:",
|
|
"Server '",
|
|
"Listening for",
|
|
"Model is overloaded",
|
|
"High demand",
|
|
"Retrying",
|
|
"retrying",
|
|
"429",
|
|
"quota",
|
|
)
|
|
|
|
|
|
def _clean_gemini_output(text: str) -> str:
|
|
lines = [
|
|
line for line in text.splitlines()
|
|
if not any(line.strip().startswith(p) for p in _GEMINI_NOISE)
|
|
]
|
|
return "\n".join(lines).strip()
|
|
|
|
|
|
async def _run(cmd: list[str], timeout: int = 60, env: dict | None = None) -> str:
|
|
loop = asyncio.get_running_loop()
|
|
result = await loop.run_in_executor(
|
|
None,
|
|
lambda: subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, env=env),
|
|
)
|
|
if result.returncode != 0:
|
|
detail = result.stderr.strip() or result.stdout.strip() or f"exit code {result.returncode}"
|
|
raise RuntimeError(f"{cmd[0]} failed: {detail}")
|
|
return result.stdout.strip()
|
|
|
|
|
|
def _build_conversation(messages: list[dict]) -> str:
|
|
"""Conversation only — used for Claude (system prompt passed separately)."""
|
|
parts = []
|
|
prior = messages[:-1]
|
|
if prior:
|
|
history_lines = []
|
|
for msg in prior:
|
|
label = settings.user_name if msg["role"] == "user" else settings.agent_name
|
|
history_lines.append(f"{label}: {msg['content']}")
|
|
parts.append("<conversation>\n" + "\n\n".join(history_lines) + "\n</conversation>")
|
|
parts.append(messages[-1]["content"] if messages else "")
|
|
return "\n\n".join(parts)
|
|
|
|
|
|
def _build_prompt(system_prompt: str, messages: list[dict]) -> str:
    """Full prompt with system context embedded — used for Gemini.

    Returns the _build_conversation() text, preceded by a <system> block and
    a blank line when a system prompt is given.
    """
    conversation = _build_conversation(messages)
    if not system_prompt:
        return conversation
    return f"<system>\n{system_prompt}\n</system>\n\n{conversation}"
|