feat: token streaming for orchestrator final response
Switches the orchestrator's final response from a fire-and-wait model to a
live SSE stream so text appears token-by-token as the model generates it.
- llm_client: complete() gains token_sink param; anthropic_api backend uses
client.messages.stream(); local backend uses httpx SSE streaming; non-streaming
backends (claude_cli, gemini_cli) emit the full text as one chunk
- orchestrator_engine + openai_orchestrator: token_sink threaded through run(),
_run_from_contents(), _claude_handoff(), and _run_from_messages()
- routers/orchestrator: each job gets an asyncio.Queue; _on_progress and
_token_sink write progress/token events to it; _finalize_job emits done,
error handler emits error, confirmation gate emits confirm; new GET
/orchestrate/{job_id}/stream SSE endpoint with 20s keepalive
- app.js: _doOrchestrate switches from 2s poll loop to EventSource; thinking
bubble converts to a streaming message on first token; auto-scroll while
streaming; confirm/error/done events handled; finalization unchanged
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -16,7 +16,8 @@ import platform
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from auth_utils import get_user_gemini_key, get_user_role, get_tool_policy, get_risk_policy
|
||||
@@ -116,6 +117,7 @@ async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse:
|
||||
"progress": None,
|
||||
"_user": user,
|
||||
"_off_record": req.off_record,
|
||||
"_event_queue": asyncio.Queue(),
|
||||
}
|
||||
|
||||
async with _jobs_lock:
|
||||
@@ -146,6 +148,45 @@ async def list_jobs() -> list[JobStatusResponse]:
|
||||
return [JobStatusResponse(**{k: v for k, v in j.items() if not k.startswith("_")}) for j in jobs]
|
||||
|
||||
|
||||
@router.get("/{job_id}/stream")
|
||||
async def stream_job(job_id: str, request: Request) -> StreamingResponse:
|
||||
"""SSE stream for a running job — emits progress, token, done, error, and confirm events."""
|
||||
import json
|
||||
|
||||
async with _jobs_lock:
|
||||
job = _jobs.get(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
||||
|
||||
# If already complete/error, emit a single done/error event immediately.
|
||||
if job["status"] == "complete":
|
||||
async def _done_now():
|
||||
yield f"data: {json.dumps({'type': 'done', 'response': job['response'], 'session_id': job.get('session_id'), 'backend': job.get('backend', ''), 'backend_label': job.get('backend_label', ''), 'host': job.get('host', ''), 'tool_calls': job.get('tool_calls')})}\n\n"
|
||||
return StreamingResponse(_done_now(), media_type="text/event-stream")
|
||||
if job["status"] == "error":
|
||||
async def _err_now():
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': job.get('error', 'Unknown error')})}\n\n"
|
||||
return StreamingResponse(_err_now(), media_type="text/event-stream")
|
||||
|
||||
queue: asyncio.Queue = job["_event_queue"]
|
||||
|
||||
async def generate():
|
||||
yield 'data: {"type":"connected"}\n\n'
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
try:
|
||||
event = await asyncio.wait_for(queue.get(), timeout=20)
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
if event["type"] in ("done", "error"):
|
||||
break
|
||||
# For confirm events: keep listening — job will resume after user action.
|
||||
except asyncio.TimeoutError:
|
||||
yield 'data: {"type":"keepalive"}\n\n'
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/{job_id}/confirm", response_model=OrchestrateResponse)
|
||||
async def confirm_job(job_id: str) -> OrchestrateResponse:
|
||||
"""Confirm a pending tool call — the blocked tool will execute and the job continues."""
|
||||
@@ -201,8 +242,18 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
|
||||
async def _on_progress(msg: str) -> None:
|
||||
async with _jobs_lock:
|
||||
if job_id in _jobs:
|
||||
_jobs[job_id]["progress"] = msg
|
||||
if job_id not in _jobs:
|
||||
return
|
||||
_jobs[job_id]["progress"] = msg
|
||||
q = _jobs[job_id].get("_event_queue")
|
||||
if q:
|
||||
await q.put({"type": "progress", "text": msg})
|
||||
|
||||
async def _token_sink(text: str) -> None:
|
||||
async with _jobs_lock:
|
||||
q = _jobs.get(job_id, {}).get("_event_queue")
|
||||
if q:
|
||||
await q.put({"type": "token", "text": text})
|
||||
|
||||
try:
|
||||
from session_store import load as load_session, save as save_session, generate_session_id
|
||||
@@ -248,6 +299,7 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
risk_whitelist=risk_wl,
|
||||
risk_blacklist=risk_bl,
|
||||
on_progress=_on_progress,
|
||||
token_sink=_token_sink,
|
||||
)
|
||||
else:
|
||||
gemini_key = (
|
||||
@@ -271,6 +323,7 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
risk_whitelist=risk_wl,
|
||||
risk_blacklist=risk_bl,
|
||||
on_progress=_on_progress,
|
||||
token_sink=_token_sink,
|
||||
)
|
||||
|
||||
if result.checkpoint:
|
||||
@@ -289,8 +342,15 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
"message": result.response,
|
||||
},
|
||||
})
|
||||
q = _jobs[job_id].get("_event_queue")
|
||||
logger.info("Orchestrator job %s awaiting confirmation — %d tool(s) blocked",
|
||||
job_id, len(result.checkpoint.pending_tools))
|
||||
if q:
|
||||
await q.put({
|
||||
"type": "confirm",
|
||||
"tools": result.checkpoint.pending_tools,
|
||||
"message": result.response,
|
||||
})
|
||||
return
|
||||
|
||||
await _finalize_job(job_id, result, session_id, req.task, history, off_record=req.off_record)
|
||||
@@ -304,6 +364,9 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
"completed_at": now,
|
||||
"error": str(e),
|
||||
})
|
||||
q = _jobs[job_id].get("_event_queue")
|
||||
if q:
|
||||
await q.put({"type": "error", "message": str(e)})
|
||||
|
||||
|
||||
async def _resume_job(
|
||||
@@ -400,4 +463,15 @@ async def _finalize_job(
|
||||
"host": host,
|
||||
"gemini_summary": result.gemini_summary,
|
||||
})
|
||||
q = _jobs[job_id].get("_event_queue")
|
||||
logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls))
|
||||
if q:
|
||||
await q.put({
|
||||
"type": "done",
|
||||
"response": result.response,
|
||||
"session_id": session_id,
|
||||
"backend": result.backend,
|
||||
"backend_label": result.backend_label or "",
|
||||
"host": host,
|
||||
"tool_calls": result.tool_calls,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user