Files
Cortex-Inara/cortex/routers/orchestrator.py
Scott Idem 9cb2b0d9a5 feat: token streaming for orchestrator final response
Switches the orchestrator's final response from a fire-and-wait model to a
live SSE stream so text appears token-by-token as the model generates it.

- llm_client: complete() gains token_sink param; anthropic_api backend uses
  client.messages.stream(); local backend uses httpx SSE streaming; non-streaming
  backends (claude_cli, gemini_cli) emit the full text as one chunk
- orchestrator_engine + openai_orchestrator: token_sink threaded through run(),
  _run_from_contents(), _claude_handoff(), and _run_from_messages()
- routers/orchestrator: each job gets an asyncio.Queue; _on_progress and
  _token_sink write progress/token events to it; _finalize_job emits done,
  error handler emits error, confirmation gate emits confirm; new GET
  /orchestrate/{job_id}/stream SSE endpoint with 20s keepalive
- app.js: _doOrchestrate switches from 2s poll loop to EventSource; thinking
  bubble converts to a streaming message on first token; auto-scroll while
  streaming; confirm/error/done events handled; finalization unchanged

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-16 23:22:50 -04:00

478 lines
18 KiB
Python

"""
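Orchestrator router — POST /orchestrate, GET /orchestrate, GET /orchestrate/{job_id},
GET /orchestrate/{job_id}/stream (SSE), POST /orchestrate/{job_id}/confirm and /deny.
Accepts a task description, runs it in the background through the orchestrator engine
(Gemini tool loop → Claude response, or a local OpenAI-compatible model), and returns
a job_id whose result can be polled or streamed.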
Designed to be triggered from:
- The Cortex web UI (future "Agent mode" toggle)
- Cron jobs: curl -X POST http://localhost:8000/orchestrate -d '{"task":"..."}'
- Webhooks: Gitea, Aether events, etc.
"""
import asyncio
import logging
import platform
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from auth_utils import get_user_gemini_key, get_user_role, get_tool_policy, get_risk_policy
from config import settings
from context_loader import load_context
from persona import set_context, validate as validate_persona
import model_registry
import orchestrator_engine
import openai_orchestrator
logger = logging.getLogger(__name__)
router = APIRouter(tags=["orchestrator"], prefix="/orchestrate")


# ---------------------------------------------------------------------------
# In-memory job store
# ---------------------------------------------------------------------------
# job_id -> job record. Keys beginning with "_" (owning user, off-record flag,
# SSE event queue, ...) are internal and are stripped before a record is
# returned to clients.
# NOTE(review): no eviction of finished jobs is visible in this module, so the
# store grows for the lifetime of the process — confirm that is acceptable.
_jobs: dict[str, dict] = dict()
_jobs_lock = asyncio.Lock()

# Paused-run checkpoints get their own map: they carry live Python objects
# (e.g. types.Content) that cannot live inside the JSON-friendly job record.
_checkpoints: dict[str, orchestrator_engine.OrchestrateCheckpoint] = dict()
_checkpoints_lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
# Request body for POST /orchestrate.
class OrchestrateRequest(BaseModel):
    task: str  # task text handed to the orchestrator engine
    session_id: str | None = None  # include session history in context; a new id is generated when omitted
    tier: int | None = None  # Inara context tier (default from settings; any falsy value falls back)
    respond_with_claude: bool = True  # False = return Gemini summary only (faster, for cron); sent as respond_with_final to local models
    include_long: bool = True  # forwarded to load_context(include_long=...)
    include_mid: bool = True  # forwarded to load_context(include_mid=...)
    include_short: bool = True  # forwarded to load_context(include_short=...)
    # NOTE(review): user/persona are chosen by the caller in the request body and
    # only checked by validate_persona(); no authentication is visible in this
    # router — confirm it is enforced upstream.
    user: str = "scott"
    persona: str = "inara"
    chat_role: str = "chat"  # role used for the final response (decoupled from tool-loop model)
    off_record: bool = False  # skip session log; inject OTR mode line into system prompt
# Response body of the job-creating/resuming endpoints: the job id and its status.
class OrchestrateResponse(BaseModel):
    job_id: str
    status: str  # "queued" | "running" | "complete" | "error" | "awaiting_confirmation"
# Public view of a job record, returned by GET /orchestrate and GET /orchestrate/{job_id}.
# Built from the in-memory job dict with its internal "_"-prefixed keys removed.
class JobStatusResponse(BaseModel):
    job_id: str
    status: str  # same values as OrchestrateResponse.status
    task: str
    created_at: str  # ISO-8601 timestamp in UTC
    completed_at: str | None = None  # ISO-8601 UTC; set when the job completes or errors
    session_id: str | None = None
    response: str | None = None  # final response text (the confirmation message while awaiting)
    tool_calls: list[dict] | None = None
    backend: str | None = None
    backend_label: str | None = None
    host: str | None = None  # platform.node() of the server that finalized the job
    gemini_summary: str | None = None
    error: str | None = None  # str(exception) when status == "error"
    pending_confirmation: dict | None = None  # {tools: [{name, args}], message: str}
    progress: str | None = None  # live status text shown in UI during run
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("", response_model=OrchestrateResponse)
async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse:
    """Submit a task to the orchestrator. Returns a job_id to poll.

    The job record is registered in the in-memory store and the actual work runs
    in a background asyncio task (``_run_job``). Progress can be polled via
    GET /orchestrate/{job_id} or streamed via GET /orchestrate/{job_id}/stream.

    Raises:
        HTTPException(400): when ``validate_persona`` rejects the user/persona pair.
    """
    try:
        user, persona = validate_persona(req.user, req.persona)
        set_context(user, persona)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    job_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc).isoformat()
    job: dict = {
        "job_id": job_id,
        "status": "queued",
        "task": req.task,
        "created_at": now,
        "completed_at": None,
        "session_id": None,
        "response": None,
        "tool_calls": None,
        "backend": None,
        "backend_label": None,
        "host": None,
        "gemini_summary": None,
        "error": None,
        "pending_confirmation": None,
        "progress": None,
        # Internal fields ("_" prefix) are filtered out of every API response.
        "_user": user,
        "_off_record": req.off_record,
        "_event_queue": asyncio.Queue(),
    }
    async with _jobs_lock:
        _jobs[job_id] = job
    # The event loop keeps only weak references to tasks; hold a strong one on
    # the job record so the background run cannot be garbage-collected mid-flight.
    job["_task"] = asyncio.create_task(_run_job(job_id, req, user))
    logger.info("Orchestrator job queued: %s — %.80s", job_id, req.task)
    return OrchestrateResponse(job_id=job_id, status="queued")
@router.get("/{job_id}", response_model=JobStatusResponse)
async def job_status(job_id: str) -> JobStatusResponse:
    """Poll the status of an orchestrator job (404 when the id is unknown)."""
    async with _jobs_lock:
        record = _jobs.get(job_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
    # Internal bookkeeping fields are prefixed with "_" and never exposed.
    public_fields = {name: value for name, value in record.items() if not name.startswith("_")}
    return JobStatusResponse(**public_fields)
@router.get("", response_model=list[JobStatusResponse])
async def list_jobs() -> list[JobStatusResponse]:
    """List all jobs, newest first. Useful for debugging."""
    async with _jobs_lock:
        records = list(_jobs.values())
    records.sort(key=lambda rec: rec["created_at"], reverse=True)

    def _public(rec: dict) -> JobStatusResponse:
        # Drop internal "_"-prefixed bookkeeping before building the model.
        return JobStatusResponse(**{name: val for name, val in rec.items() if not name.startswith("_")})

    return [_public(rec) for rec in records]
@router.get("/{job_id}/stream")
async def stream_job(job_id: str, request: Request) -> StreamingResponse:
    """SSE stream for a running job — emits progress, token, done, error, and confirm events.

    Every frame is a ``data: <json>`` line. A "connected" frame is sent first and
    a "keepalive" frame after 20 s without events. A job that has already finished
    gets a single "done" or "error" frame. Otherwise the stream ends after a
    "done"/"error" event or when the client disconnects. After "confirm" it keeps
    listening, because the job resumes once the user confirms or denies.

    NOTE: events come from the job's single asyncio.Queue, and each ``get()``
    removes an item, so concurrent subscribers to one job split the events.

    Raises:
        HTTPException(404): unknown job_id.
    """
    import json

    async with _jobs_lock:
        job = _jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job {job_id} not found")

    # Keep caches and buffering reverse proxies (nginx honours X-Accel-Buffering)
    # from holding frames back, which would defeat token-by-token streaming.
    headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}

    def _sse(event: dict) -> str:
        return f"data: {json.dumps(event)}\n\n"

    async def _one_shot(event: dict):
        yield _sse(event)

    # Already finished: replay the terminal event once and close.
    if job["status"] == "complete":
        done = {
            "type": "done",
            "response": job["response"],
            "session_id": job.get("session_id"),
            "backend": job.get("backend", ""),
            # Match the live "done" event, which sends "" rather than null.
            "backend_label": job.get("backend_label") or "",
            "host": job.get("host", ""),
            "tool_calls": job.get("tool_calls"),
        }
        return StreamingResponse(_one_shot(done), media_type="text/event-stream", headers=headers)
    if job["status"] == "error":
        # The "error" key always exists (initially None), so a .get() default
        # would never apply; fall back for None or an empty exception message.
        failed = {"type": "error", "message": job.get("error") or "Unknown error"}
        return StreamingResponse(_one_shot(failed), media_type="text/event-stream", headers=headers)

    queue: asyncio.Queue = job["_event_queue"]

    async def generate():
        yield 'data: {"type":"connected"}\n\n'
        while not await request.is_disconnected():
            try:
                event = await asyncio.wait_for(queue.get(), timeout=20)
            except asyncio.TimeoutError:
                yield 'data: {"type":"keepalive"}\n\n'
                continue
            yield _sse(event)
            # "confirm" is deliberately not terminal: the job resumes after the
            # user acts and its later events arrive on this same stream.
            if event["type"] in ("done", "error"):
                break

    return StreamingResponse(generate(), media_type="text/event-stream", headers=headers)
async def _resume_pending(job_id: str, *, confirmed: bool) -> None:
    """Claim the pending checkpoint for ``job_id`` and resume the job in the background.

    Shared by the confirm and deny endpoints. On success the job is flipped back
    to "running", its pending_confirmation is cleared and ``_resume_job`` is
    scheduled with ``confirmed``.

    Raises:
        HTTPException(404): no checkpoint is pending for this job.
        HTTPException(409): the job is missing or not "awaiting_confirmation".
            If the job exists, its checkpoint is put back instead of discarded.
    """
    async with _checkpoints_lock:
        checkpoint = _checkpoints.pop(job_id, None)
    if checkpoint is None:
        raise HTTPException(status_code=404, detail="No pending confirmation for this job")

    async with _jobs_lock:
        job = _jobs.get(job_id)
        claimable = bool(job) and job["status"] == "awaiting_confirmation"
        if claimable:
            job["status"] = "running"
            job["pending_confirmation"] = None
            user = job.get("_user", "scott")
            # The event loop keeps only weak references to tasks; hold a strong
            # one on the job record so the resume cannot be garbage-collected.
            job["_task"] = asyncio.create_task(
                _resume_job(job_id, checkpoint, confirmed=confirmed, user=user)
            )

    if not claimable:
        if job:
            # The checkpoint is stored and the status updated under two separate
            # lock acquisitions elsewhere; never let a rejected request destroy it.
            async with _checkpoints_lock:
                _checkpoints.setdefault(job_id, checkpoint)
        raise HTTPException(status_code=409, detail="Job is not awaiting confirmation")


@router.post("/{job_id}/confirm", response_model=OrchestrateResponse)
async def confirm_job(job_id: str) -> OrchestrateResponse:
    """Confirm a pending tool call — the blocked tool will execute and the job continues."""
    await _resume_pending(job_id, confirmed=True)
    logger.info("Orchestrator job %s confirmed — resuming", job_id)
    return OrchestrateResponse(job_id=job_id, status="running")


@router.post("/{job_id}/deny", response_model=OrchestrateResponse)
async def deny_job(job_id: str) -> OrchestrateResponse:
    """Deny a pending tool call — the tool is skipped and the job produces a final response."""
    await _resume_pending(job_id, confirmed=False)
    logger.info("Orchestrator job %s denied — resuming with skip", job_id)
    return OrchestrateResponse(job_id=job_id, status="running")
# ---------------------------------------------------------------------------
# Background runners
# ---------------------------------------------------------------------------
async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
    """Execute the orchestration job and update the job store.

    Builds the system prompt and session history for ``req``, runs the
    orchestrator backend configured for ``user``, and then does one of two things:
    - parks the job for tool confirmation (checkpoint stored, "confirm" event queued);
    - hands the result to ``_finalize_job``.

    Progress text and streamed tokens are pushed to the job's SSE event queue as
    they arrive. Any exception marks the job as "error" and queues an "error" event.
    """
    async with _jobs_lock:
        job = _jobs[job_id]
        job["status"] = "running"
        # The queue is created together with the job and never replaced, so
        # fetch it once instead of taking the store lock for every token.
        queue = job.get("_event_queue")

    async def _emit(event: dict) -> None:
        if queue is not None:
            await queue.put(event)

    async def _on_progress(msg: str) -> None:
        async with _jobs_lock:
            job["progress"] = msg
        await _emit({"type": "progress", "text": msg})

    async def _token_sink(text: str) -> None:
        await _emit({"type": "token", "text": text})

    try:
        from session_store import load as load_session, generate_session_id

        tier = req.tier or settings.default_tier
        role_cfg = model_registry.get_role_config(user, req.chat_role)
        system_prompt = load_context(
            tier,
            include_long=req.include_long,
            include_mid=req.include_mid,
            include_short=req.include_short,
            role_append=role_cfg.get("system_append", ""),
            inject_datetime=role_cfg.get("inject_datetime", True),
            inject_mode=role_cfg.get("inject_mode", True),
            mode="otr" if req.off_record else "chat",
        )
        session_id = req.session_id or generate_session_id()
        history = load_session(session_id)

        result = await _call_orchestrator(
            req,
            user,
            system_prompt=system_prompt,
            session_messages=history or None,
            role_cfg=role_cfg,
            on_progress=_on_progress,
            token_sink=_token_sink,
        )

        if result.checkpoint:
            # A tool call needs user approval: park the job until confirm/deny.
            async with _checkpoints_lock:
                _checkpoints[job_id] = result.checkpoint
            async with _jobs_lock:
                job.update({
                    "status": "awaiting_confirmation",
                    "response": result.response,
                    "tool_calls": result.tool_calls,
                    "backend": result.backend,
                    "gemini_summary": result.gemini_summary,
                    "session_id": session_id,
                    "pending_confirmation": {
                        "tools": result.checkpoint.pending_tools,
                        "message": result.response,
                    },
                })
            logger.info("Orchestrator job %s awaiting confirmation — %d tool(s) blocked",
                        job_id, len(result.checkpoint.pending_tools))
            await _emit({
                "type": "confirm",
                "tools": result.checkpoint.pending_tools,
                "message": result.response,
            })
            return

        await _finalize_job(job_id, result, session_id, req.task, history, off_record=req.off_record)
    except Exception as e:
        logger.exception("Orchestrator job failed: %s", job_id)
        now = datetime.now(timezone.utc).isoformat()
        async with _jobs_lock:
            job.update({
                "status": "error",
                "completed_at": now,
                "error": str(e),
            })
        await _emit({"type": "error", "message": str(e)})


async def _call_orchestrator(
    req: OrchestrateRequest,
    user: str,
    *,
    system_prompt,
    session_messages,
    role_cfg,
    on_progress,
    token_sink,
):
    """Run ``req`` on the backend configured for the user's "orchestrator" role.

    A model whose "type" is "local_openai" is dispatched to openai_orchestrator.
    Anything else, including no configured model, goes to orchestrator_engine.
    The user's tool and risk policies are passed to either backend.
    """
    orch_model = model_registry.get_model_for_role(user, "orchestrator")
    user_role = get_user_role(user)
    tool_list = role_cfg.get("tools")
    policy = get_tool_policy(user)
    confirm_allow = set(policy.get("allow", []))
    confirm_deny = set(policy.get("deny", []))
    max_risk, risk_wl, risk_bl = get_risk_policy(user)

    if orch_model and orch_model.get("type") == "local_openai":
        return await openai_orchestrator.run(
            task=req.task,
            system_prompt=system_prompt,
            session_messages=session_messages,
            model_cfg=orch_model,
            respond_with_final=req.respond_with_claude,
            user_role=user_role,
            tool_list=tool_list,
            confirm_allow=confirm_allow,
            confirm_deny=confirm_deny,
            max_risk=max_risk,
            risk_whitelist=risk_wl,
            risk_blacklist=risk_bl,
            on_progress=on_progress,
            token_sink=token_sink,
        )

    # Prefer a key configured on the model entry, else the user's own Gemini key.
    gemini_key = (
        (orch_model.get("api_key") if orch_model else None)
        or get_user_gemini_key(user)
    )
    return await orchestrator_engine.run(
        task=req.task,
        system_prompt=system_prompt,
        session_messages=session_messages,
        respond_with_claude=req.respond_with_claude,
        gemini_api_key=gemini_key,
        model_name=orch_model.get("model_name") if orch_model else None,
        response_role=req.chat_role,
        user_role=user_role,
        tool_list=tool_list,
        confirm_allow=confirm_allow,
        confirm_deny=confirm_deny,
        max_rounds=orch_model.get("max_rounds") if orch_model else None,
        max_risk=max_risk,
        risk_whitelist=risk_wl,
        risk_blacklist=risk_bl,
        on_progress=on_progress,
        token_sink=token_sink,
    )
async def _resume_job(
    job_id: str,
    checkpoint: orchestrator_engine.OrchestrateCheckpoint,
    confirmed: bool,
    user: str,
) -> None:
    """Resume a job after the user confirms or denies a pending tool call.

    The checkpoint is handed back to the engine that produced it. There are two
    outcomes:
    - Another confirmation gate: the job is parked again (new checkpoint stored,
      "confirm" event queued).
    - Otherwise: the job is finalized.

    Failures mark the job as "error" and queue an "error" event, so SSE
    subscribers are released instead of waiting on keepalives.

    Args:
        job_id: id of the job being resumed.
        checkpoint: paused run state taken from the checkpoint store.
        confirmed: True to run the blocked tool(s), False to skip them.
        user: owner of the job; not used by this function body.

    Note:
        The engines' ``resume()`` calls receive no on_progress/token_sink, so the
        resumed portion emits no progress/token events — only confirm/done/error.
    """
    async def _emit(event: dict) -> None:
        # Push onto the job's SSE queue if the job and its queue still exist.
        async with _jobs_lock:
            queue = _jobs.get(job_id, {}).get("_event_queue")
        if queue is not None:
            await queue.put(event)

    try:
        if checkpoint.engine == "gemini":
            result = await orchestrator_engine.resume(checkpoint, confirmed)
        else:
            result = await openai_orchestrator.resume(checkpoint, confirmed)

        if result.checkpoint:
            # Another confirmation needed (chained gates)
            async with _checkpoints_lock:
                _checkpoints[job_id] = result.checkpoint
            async with _jobs_lock:
                _jobs[job_id].update({
                    "status": "awaiting_confirmation",
                    "response": result.response,
                    "tool_calls": result.tool_calls,
                    "backend": result.backend,
                    "gemini_summary": result.gemini_summary,
                    "pending_confirmation": {
                        "tools": result.checkpoint.pending_tools,
                        "message": result.response,
                    },
                })
            logger.info("Orchestrator job %s awaiting another confirmation", job_id)
            # Tell the SSE client about the new gate; without it the stream
            # would only ever see keepalives.
            await _emit({
                "type": "confirm",
                "tools": result.checkpoint.pending_tools,
                "message": result.response,
            })
            return

        async with _jobs_lock:
            session_id = _jobs[job_id].get("session_id") or ""
            task = _jobs[job_id].get("task", "")
            off_record = _jobs[job_id].get("_off_record", False)
        from session_store import load as load_session
        history = load_session(session_id) if session_id else []
        await _finalize_job(job_id, result, session_id, task, history, off_record=off_record)
    except Exception as e:
        logger.exception("Orchestrator resume failed: %s", job_id)
        now = datetime.now(timezone.utc).isoformat()
        async with _jobs_lock:
            _jobs[job_id].update({
                "status": "error",
                "completed_at": now,
                "error": str(e),
            })
        await _emit({"type": "error", "message": str(e)})
async def _finalize_job(
    job_id: str,
    result: orchestrator_engine.OrchestratorResult,
    session_id: str,
    task: str,
    history: list | None,
    off_record: bool = False,
) -> None:
    """Save session, log the turn, and mark the job complete.

    Appends the user task and the assistant reply to the session history and
    saves it. The turn is logged unless off the record. The result is recorded
    on the job, and a final "done" event is queued for SSE subscribers.

    Args:
        job_id: job to complete.
        result: engine result providing response/tool_calls/backend fields.
        session_id: session to append to; a new id is generated when empty.
        task: the user's task text, stored as the user turn.
        history: prior session messages. The caller's list is not mutated (a
            copy is extended); ``None`` is treated as an empty history.
        off_record: tags both new turns as off the record and skips log_turn.
    """
    from session_store import save as save_session, generate_session_id
    from session_logger import log_turn

    if not session_id:
        session_id = generate_session_id()
    host = platform.node()

    turns = list(history or [])
    turns.append({"role": "user", "content": task, "off_record": off_record})
    turns.append({
        "role": "assistant",
        "content": result.response,
        "backend": result.backend,
        "backend_label": result.backend_label,
        "host": host,
        "off_record": off_record,
    })
    save_session(session_id, turns)
    if not off_record:
        log_turn(session_id, task, result.response)

    now = datetime.now(timezone.utc).isoformat()
    async with _jobs_lock:
        _jobs[job_id].update({
            "status": "complete",
            "completed_at": now,
            "session_id": session_id,
            "response": result.response,
            "tool_calls": result.tool_calls,
            "backend": result.backend,
            "backend_label": result.backend_label,
            "host": host,
            "gemini_summary": result.gemini_summary,
        })
        q = _jobs[job_id].get("_event_queue")
    # tool_calls is typed optional on the job record; a None here must not raise
    # after the job has already been marked complete.
    logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls or []))
    if q:
        await q.put({
            "type": "done",
            "response": result.response,
            "session_id": session_id,
            "backend": result.backend,
            "backend_label": result.backend_label or "",
            "host": host,
            "tool_calls": result.tool_calls,
        })