From 9cb2b0d9a5be50441ce4bc25ebb9dc1fa219d4e5 Mon Sep 17 00:00:00 2001 From: Scott Idem Date: Tue, 16 Jun 2026 23:22:50 -0400 Subject: [PATCH] feat: token streaming for orchestrator final response Switches the orchestrator's final response from a fire-and-wait model to a live SSE stream so text appears token-by-token as the model generates it. - llm_client: complete() gains token_sink param; anthropic_api backend uses client.messages.stream(); local backend uses httpx SSE streaming; non-streaming backends (claude_cli, gemini_cli) emit the full text as one chunk - orchestrator_engine + openai_orchestrator: token_sink threaded through run(), _run_from_contents(), _claude_handoff(), and _run_from_messages() - routers/orchestrator: each job gets an asyncio.Queue; _on_progress and _token_sink write progress/token events to it; _finalize_job emits done, error handler emits error, confirmation gate emits confirm; new GET /orchestrate/{job_id}/stream SSE endpoint with 20s keepalive - app.js: _doOrchestrate switches from 2s poll loop to EventSource; thinking bubble converts to a streaming message on first token; auto-scroll while streaming; confirm/error/done events handled; finalization unchanged Co-Authored-By: Claude Sonnet 4.6 --- cortex/llm_client.py | 121 ++++++++++++++++++++++++++++++--- cortex/openai_orchestrator.py | 5 ++ cortex/orchestrator_engine.py | 10 +++ cortex/routers/orchestrator.py | 80 +++++++++++++++++++++- cortex/static/app.js | 116 +++++++++++++++++-------------- documentation/TODO__Agents.md | 24 +++++++ 6 files changed, 293 insertions(+), 63 deletions(-) diff --git a/cortex/llm_client.py b/cortex/llm_client.py index 24a316e..1763d6b 100644 --- a/cortex/llm_client.py +++ b/cortex/llm_client.py @@ -53,6 +53,7 @@ async def complete( slot: str | None = None, max_tokens: int = 2048, attachment: dict | None = None, + token_sink=None, # async (str) -> None; if set, stream tokens as they arrive ) -> tuple[str, str]: """ Returns (response_text, 
actual_backend_used). @@ -98,7 +99,8 @@ async def complete( fallback = _FALLBACK.get(primary, "claude") try: - response = await _dispatch(primary, system_prompt, messages, resolved_cfg, attachment=attachment) + response = await _dispatch(primary, system_prompt, messages, resolved_cfg, + attachment=attachment, token_sink=token_sink) return response, primary except Exception as e: err_str = str(e) @@ -109,7 +111,7 @@ async def complete( logger.error("%s failed (no fallback — model explicitly configured): %s", primary, e) raise logger.warning("%s failed (%s) — falling back to %s", primary, e, fallback) - response = await _dispatch(fallback, system_prompt, messages, None) + response = await _dispatch(fallback, system_prompt, messages, None, token_sink=token_sink) return response, fallback @@ -119,14 +121,24 @@ async def _dispatch( messages: list[dict], model_cfg: dict | None, attachment: dict | None = None, + token_sink=None, ) -> str: if backend == "gemini": - return await _gemini(system_prompt, messages) - if backend == "local": - return await _local(system_prompt, messages, model_cfg, attachment=attachment) - if backend == "anthropic_api": - return await _anthropic_api(system_prompt, messages, model_cfg) - return await _claude(system_prompt, messages, model_cfg) + text = await _gemini(system_prompt, messages) + elif backend == "local": + if token_sink: + return await _local_streaming(token_sink, system_prompt, messages, model_cfg) + text = await _local(system_prompt, messages, model_cfg, attachment=attachment) + elif backend == "anthropic_api": + if token_sink: + return await _anthropic_api_streaming(token_sink, system_prompt, messages, model_cfg) + text = await _anthropic_api(system_prompt, messages, model_cfg) + else: + text = await _claude(system_prompt, messages, model_cfg) + # For non-streaming backends when token_sink is provided, emit the full text as one chunk. 
+ if token_sink and text: + await token_sink(text) + return text def _fresh_claude_token() -> str | None: @@ -302,6 +314,99 @@ async def _anthropic_api(system_prompt: str, messages: list[dict], model_cfg: di return text.strip() +async def _anthropic_api_streaming( + token_sink, system_prompt: str, messages: list[dict], model_cfg: dict | None +) -> str: + try: + import anthropic + except ImportError: + raise RuntimeError("anthropic SDK not installed — run: pip install 'anthropic>=0.40.0'") + + cfg = model_cfg or {} + api_key = cfg.get("api_key", "") + model_name = cfg.get("model_name") or settings.default_model + + if not api_key: + raise RuntimeError("No Anthropic API key — add one at /settings/models") + + client = anthropic.AsyncAnthropic(api_key=api_key) + msgs = [{"role": m["role"], "content": m["content"]} for m in messages] + kwargs: dict = {"model": model_name, "max_tokens": 4096, "messages": msgs} + if system_prompt: + kwargs["system"] = system_prompt + + full_text = "" + async with client.messages.stream(**kwargs) as stream: + async for chunk in stream.text_stream: + await token_sink(chunk) + full_text += chunk + + final_msg = await stream.get_final_message() + if final_msg.usage: + import usage_tracker + from persona import _user + asyncio.create_task(usage_tracker.record( + username=_user.get(), + backend="anthropic_api", + model_name=model_name, + prompt_tokens=final_msg.usage.input_tokens, + completion_tokens=final_msg.usage.output_tokens, + )) + + return full_text.strip() + + +async def _local_streaming( + token_sink, system_prompt: str, messages: list[dict], model_cfg: dict | None +) -> str: + import httpx + import json as _json + + cfg = model_cfg or {} + api_url = cfg.get("api_url", "") + api_key = cfg.get("api_key", "") + model = cfg.get("model_name", "") + host_type = cfg.get("host_type", "openwebui") + + if not api_url: + raise RuntimeError("local_api_url not configured") + if not model: + raise RuntimeError("local_model not configured") + + 
chat_path = "/chat/completions" if host_type == "openai" else "/api/chat/completions" + url = api_url.rstrip("/") + chat_path + headers: dict[str, str] = {"Authorization": f"Bearer {api_key}"} if api_key else {} + + msgs: list[dict] = [] + if system_prompt: + msgs.append({"role": "system", "content": system_prompt}) + for m in messages: + msgs.append({"role": m["role"], "content": m["content"]}) + + payload = {"model": model, "messages": msgs, "stream": True} + full_text = "" + + async with httpx.AsyncClient(timeout=settings.timeout_local) as client: + async with client.stream("POST", url, json=payload, headers=headers) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line or not line.startswith("data: "): + continue + data_str = line[6:].strip() + if data_str == "[DONE]": + break + try: + chunk = _json.loads(data_str) + delta = (chunk["choices"][0]["delta"].get("content") or "") + if delta: + await token_sink(delta) + full_text += delta + except Exception: + pass + + return full_text.strip() + + async def _gemini(system_prompt: str, messages: list[dict]) -> str: # Gemini CLI spawns MCP child processes that keep stdout pipes open after responding. # start_new_session=True puts the whole tree in its own process group so diff --git a/cortex/openai_orchestrator.py b/cortex/openai_orchestrator.py index d301533..fce7471 100644 --- a/cortex/openai_orchestrator.py +++ b/cortex/openai_orchestrator.py @@ -53,6 +53,7 @@ async def run( risk_whitelist: list[str] | None = None, risk_blacklist: list[str] | None = None, on_progress=None, # async (str) -> None; called with live status updates + token_sink=None, # async (str) -> None; called with each response token ) -> OrchestratorResult: """ Run a tool-enabled task using an OpenAI-compatible API. 
@@ -119,6 +120,7 @@ async def run( confirm_deny=_confirm_deny, starting_round=0, on_progress=on_progress, + token_sink=token_sink, ) if checkpoint: @@ -310,6 +312,7 @@ async def _run_from_messages( starting_round: int = 0, tool_list: list[str] | None = None, on_progress=None, + token_sink=None, ) -> tuple[str, OrchestrateCheckpoint | None]: """ Run the OpenAI ReAct loop from the current messages state. @@ -425,6 +428,8 @@ async def _run_from_messages( if on_progress: await on_progress("Generating response…") final_response = msg.content or "" + if token_sink and final_response: + await token_sink(final_response) logger.info( "OpenAI orchestrator done after %d round(s). Tools used: %d", round_num + 1, len(tool_call_log), diff --git a/cortex/orchestrator_engine.py b/cortex/orchestrator_engine.py index d6196b3..100a4ce 100644 --- a/cortex/orchestrator_engine.py +++ b/cortex/orchestrator_engine.py @@ -121,6 +121,7 @@ async def run( risk_whitelist: list[str] | None = None, risk_blacklist: list[str] | None = None, on_progress=None, # async (str) -> None; called with live status updates + token_sink=None, # async (str) -> None; called with each response token ) -> OrchestratorResult: """ Run the full orchestration loop for a task. @@ -185,6 +186,7 @@ async def run( gemini_api_key=api_key, max_rounds=max_rounds, on_progress=on_progress, + token_sink=token_sink, ) if checkpoint: @@ -207,6 +209,7 @@ async def run( session_messages=session_messages, respond_with_claude=respond_with_claude, response_role=response_role, + token_sink=token_sink, ) @@ -270,6 +273,8 @@ async def resume(checkpoint: OrchestrateCheckpoint, confirmed: bool) -> Orchestr gemini_api_key=api_key, max_rounds=checkpoint.max_rounds, ) + # Note: resume() takes no token_sink, so a resumed job's final response is not streamed + # token-by-token; it is delivered whole once the resumed job finishes.
if new_checkpoint: return OrchestratorResult( @@ -312,6 +317,7 @@ async def _run_from_contents( tool_list: list[str] | None = None, max_rounds: int | None = None, on_progress=None, + token_sink=None, ) -> tuple[str, OrchestrateCheckpoint | None]: """ Run the ReAct loop from the current contents state. @@ -454,6 +460,7 @@ async def _claude_handoff( session_messages: list[dict] | None, respond_with_claude: bool, response_role: str, + token_sink=None, ) -> OrchestratorResult: if respond_with_claude: claude_prompt = _build_claude_prompt(task, tool_call_log, gemini_summary) @@ -463,10 +470,13 @@ async def _claude_handoff( system_prompt=system_prompt, messages=messages, role=response_role, + token_sink=token_sink, ) else: response_text = gemini_summary or "No information gathered." backend = "gemini" + if token_sink and response_text: + await token_sink(response_text) return OrchestratorResult( response=response_text, diff --git a/cortex/routers/orchestrator.py b/cortex/routers/orchestrator.py index 4092a02..2ba26d6 100644 --- a/cortex/routers/orchestrator.py +++ b/cortex/routers/orchestrator.py @@ -16,7 +16,8 @@ import platform import uuid from datetime import datetime, timezone -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse from pydantic import BaseModel from auth_utils import get_user_gemini_key, get_user_role, get_tool_policy, get_risk_policy @@ -116,6 +117,7 @@ async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse: "progress": None, "_user": user, "_off_record": req.off_record, + "_event_queue": asyncio.Queue(), } async with _jobs_lock: @@ -146,6 +148,45 @@ async def list_jobs() -> list[JobStatusResponse]: return [JobStatusResponse(**{k: v for k, v in j.items() if not k.startswith("_")}) for j in jobs] +@router.get("/{job_id}/stream") +async def stream_job(job_id: str, request: Request) -> StreamingResponse: + """SSE stream for a running job — emits 
progress, token, done, error, and confirm events.""" + import json + + async with _jobs_lock: + job = _jobs.get(job_id) + if job is None: + raise HTTPException(status_code=404, detail=f"Job {job_id} not found") + + # If already complete/error, emit a single done/error event immediately. + if job["status"] == "complete": + async def _done_now(): + yield f"data: {json.dumps({'type': 'done', 'response': job['response'], 'session_id': job.get('session_id'), 'backend': job.get('backend', ''), 'backend_label': job.get('backend_label', ''), 'host': job.get('host', ''), 'tool_calls': job.get('tool_calls')})}\n\n" + return StreamingResponse(_done_now(), media_type="text/event-stream") + if job["status"] == "error": + async def _err_now(): + yield f"data: {json.dumps({'type': 'error', 'message': job.get('error', 'Unknown error')})}\n\n" + return StreamingResponse(_err_now(), media_type="text/event-stream") + + queue: asyncio.Queue = job["_event_queue"] + + async def generate(): + yield 'data: {"type":"connected"}\n\n' + while True: + if await request.is_disconnected(): + break + try: + event = await asyncio.wait_for(queue.get(), timeout=20) + yield f"data: {json.dumps(event)}\n\n" + if event["type"] in ("done", "error"): + break + # For confirm events: keep listening — job will resume after user action. 
+ except asyncio.TimeoutError: + yield 'data: {"type":"keepalive"}\n\n' + + return StreamingResponse(generate(), media_type="text/event-stream") + + @router.post("/{job_id}/confirm", response_model=OrchestrateResponse) async def confirm_job(job_id: str) -> OrchestrateResponse: """Confirm a pending tool call — the blocked tool will execute and the job continues.""" @@ -201,8 +242,18 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: async def _on_progress(msg: str) -> None: async with _jobs_lock: - if job_id in _jobs: - _jobs[job_id]["progress"] = msg + if job_id not in _jobs: + return + _jobs[job_id]["progress"] = msg + q = _jobs[job_id].get("_event_queue") + if q: + await q.put({"type": "progress", "text": msg}) + + async def _token_sink(text: str) -> None: + async with _jobs_lock: + q = _jobs.get(job_id, {}).get("_event_queue") + if q: + await q.put({"type": "token", "text": text}) try: from session_store import load as load_session, save as save_session, generate_session_id @@ -248,6 +299,7 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: risk_whitelist=risk_wl, risk_blacklist=risk_bl, on_progress=_on_progress, + token_sink=_token_sink, ) else: gemini_key = ( @@ -271,6 +323,7 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: risk_whitelist=risk_wl, risk_blacklist=risk_bl, on_progress=_on_progress, + token_sink=_token_sink, ) if result.checkpoint: @@ -289,8 +342,15 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: "message": result.response, }, }) + q = _jobs[job_id].get("_event_queue") logger.info("Orchestrator job %s awaiting confirmation — %d tool(s) blocked", job_id, len(result.checkpoint.pending_tools)) + if q: + await q.put({ + "type": "confirm", + "tools": result.checkpoint.pending_tools, + "message": result.response, + }) return await _finalize_job(job_id, result, session_id, req.task, history, off_record=req.off_record) @@ -304,6 +364,9 
@@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: "completed_at": now, "error": str(e), }) + q = _jobs[job_id].get("_event_queue") + if q: + await q.put({"type": "error", "message": str(e)}) async def _resume_job( @@ -400,4 +463,15 @@ async def _finalize_job( "host": host, "gemini_summary": result.gemini_summary, }) + q = _jobs[job_id].get("_event_queue") logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls)) + if q: + await q.put({ + "type": "done", + "response": result.response, + "session_id": session_id, + "backend": result.backend, + "backend_label": result.backend_label or "", + "host": host, + "tool_calls": result.tool_calls, + }) diff --git a/cortex/static/app.js b/cortex/static/app.js index a35bbf9..546d5e9 100644 --- a/cortex/static/app.js +++ b/cortex/static/app.js @@ -1475,68 +1475,79 @@ if (!res.ok) throw new Error(`HTTP ${res.status}`); const { job_id } = await res.json(); - // Poll until complete or stopped - let job; - while (true) { - if (activeController.signal.aborted) throw new DOMException('Aborted', 'AbortError'); + // Stream events from the job via SSE + const job = await new Promise((resolve, reject) => { + const es = new EventSource(`/orchestrate/${job_id}/stream`); + let streamingStarted = false; + let accumulatedText = ''; - await new Promise(r => setTimeout(r, 2000)); + const abort = activeController.signal; + abort.addEventListener('abort', () => { es.close(); reject(new DOMException('Aborted', 'AbortError')); }); - if (activeController.signal.aborted) throw new DOMException('Aborted', 'AbortError'); + es.onmessage = async (e) => { + let event; + try { event = JSON.parse(e.data); } catch { return; } - const pollRes = await fetch(`/orchestrate/${job_id}`, { - signal: activeController.signal, - }); - if (!pollRes.ok) throw new Error(`Poll failed: HTTP ${pollRes.status}`); - job = await pollRes.json(); + if (event.type === 'connected' || event.type === 'keepalive') return; 
- if (job.status === 'queued' || job.status === 'running') { - const prog = job.progress; - const n = job.tool_calls?.length || 0; - if (prog) { - thinkingDiv.textContent = `⚡ ${prog}`; - } else { - thinkingDiv.textContent = n - ? `⚡ working… (${n} tool${n !== 1 ? 's' : ''} used)` - : '⚡ working…'; + if (event.type === 'progress') { + if (!streamingStarted) thinkingDiv.textContent = `⚡ ${event.text}`; + return; } - continue; - } - if (job.status === 'awaiting_confirmation') { - const pc = job.pending_confirmation || {}; - const toolNames = (pc.tools || []).map(t => t.name).join(', '); - thinkingDiv.className = 'message assistant'; - thinkingDiv.innerHTML = `
-

${escapeHtml(pc.message || 'Confirm this action?')}

-

Tool${(pc.tools||[]).length !== 1 ? 's' : ''}: ${escapeHtml(toolNames)}

-
- - -
-
`; + if (event.type === 'token') { + if (!streamingStarted) { + streamingStarted = true; + thinkingDiv.className = 'message assistant'; + thinkingDiv.innerHTML = ''; + } + accumulatedText += event.text; + setMessageText(thinkingDiv, 'assistant', accumulatedText); + thinkingDiv.scrollIntoView({ behavior: 'smooth', block: 'end' }); + return; + } - const confirmed = await new Promise(resolve => { - thinkingDiv.querySelector('.confirm-btn').onclick = () => resolve(true); - thinkingDiv.querySelector('.deny-btn').onclick = () => resolve(false); - }); + if (event.type === 'confirm') { + const pc = event; + const toolNames = (pc.tools || []).map(t => t.name).join(', '); + thinkingDiv.className = 'message assistant'; + thinkingDiv.innerHTML = `
+

${escapeHtml(pc.message || 'Confirm this action?')}

+

Tool${(pc.tools||[]).length !== 1 ? 's' : ''}: ${escapeHtml(toolNames)}

+
+ + +
+
`; + const confirmed = await new Promise(r => { + thinkingDiv.querySelector('.confirm-btn').onclick = () => r(true); + thinkingDiv.querySelector('.deny-btn').onclick = () => r(false); + }); + thinkingDiv.className = 'message assistant thinking'; + thinkingDiv.textContent = confirmed ? '⚡ confirmed — continuing…' : '⚡ denied — finishing…'; + streamingStarted = false; + accumulatedText = ''; + const action = confirmed ? 'confirm' : 'deny'; + await fetch(`/orchestrate/${job_id}/${action}`, { method: 'POST' }); + return; + } - thinkingDiv.className = 'message assistant thinking'; - thinkingDiv.textContent = confirmed ? '⚡ confirmed — continuing…' : '⚡ denied — finishing…'; + if (event.type === 'error') { + es.close(); + reject(new Error(event.message || 'Orchestrator failed')); + return; + } - const action = confirmed ? 'confirm' : 'deny'; - const resumeRes = await fetch(`/orchestrate/${job_id}/${action}`, { - method: 'POST', - signal: activeController.signal, - }); - if (!resumeRes.ok) throw new Error(`Resume failed: HTTP ${resumeRes.status}`); - continue; - } + if (event.type === 'done') { + es.close(); + // If we received tokens, the response is already rendered — + // use accumulatedText; otherwise fall back to event.response. + resolve({ ...event, response: accumulatedText || event.response }); + } + }; - break; - } - - if (job.status === 'error') throw new Error(job.error || 'Orchestrator failed'); + es.onerror = () => { es.close(); reject(new Error('Stream connection lost')); }; + }); // Update session so this turn is part of the resumable history if (job.session_id) { @@ -1548,6 +1559,7 @@ const userHistIdx = currentHistory.length - 1; // pushed before fetch attachHistoryControls(userMsgDiv, userHistIdx); + // If tokens streamed, the div is already a message; if not, set text now. 
thinkingDiv.className = 'message assistant'; setMessageText(thinkingDiv, 'assistant', job.response || '(no response)'); const assistHistIdx = currentHistory.length; diff --git a/documentation/TODO__Agents.md b/documentation/TODO__Agents.md index fbad3d0..495ed39 100644 --- a/documentation/TODO__Agents.md +++ b/documentation/TODO__Agents.md @@ -249,6 +249,30 @@ model costs down as sessions grow. Not continuous per-token — checkpoint-trigg heuristic handles the worst cases. Priority rises with dev-agent pipeline work where aider tool results can be very large. +### [UX] Token streaming for orchestrator final response ✅ — 2026-06-16 +Text appears token-by-token while the model is generating, instead of waiting for the +full response after "Generating response…" completes. + +- [x] **`llm_client.py`** — `complete()` gains `token_sink` param; `_dispatch()` routes to + streaming variants when set; `_anthropic_api_streaming()` uses `client.messages.stream()`; + `_local_streaming()` uses `httpx client.stream()` + SSE parsing; non-streaming backends + (claude_cli, gemini_cli) emit full text as one chunk via `token_sink` +- [x] **`orchestrator_engine.py`** — `run()`, `_run_from_contents()`, and `_claude_handoff()` + all accept and thread `token_sink`; Gemini handoff to Claude/Anthropic API is the + primary streaming path +- [x] **`openai_orchestrator.py`** — `run()` and `_run_from_messages()` accept `token_sink`; + local model final response emitted via `token_sink` (one chunk for now; true streaming + left for future polish) +- [x] **`routers/orchestrator.py`** — each job gets an `asyncio.Queue` (`_event_queue`); + `_on_progress` and `_token_sink` write to the queue as events (`{type, text}`); + `_finalize_job` emits `{type: done, ...}`, error handler emits `{type: error, ...}`, + confirmation gate emits `{type: confirm, ...}`; new `GET /orchestrate/{job_id}/stream` + SSE endpoint with 20s keepalive timeout; handles already-complete/error jobs immediately +- [x] 
**`static/app.js`** — `_doOrchestrate` switches from poll loop to `EventSource`; renders + thinking-bubble progress labels on `progress` events; converts bubble to streaming message + on first `token` event (with auto-scroll); handles `confirm`, `error`, `done` events; + finalization (metadata, history controls, tool calls) runs after `done` + ### [Auth] Encrypted sessions Allow users to opt-in to per-session encryption so session logs on disk cannot be read without the user's key.