feat: proper confirmation-resume flow + per-user tool policy
Fixes the broken confirmation gate where users had no way to approve
or deny a blocked tool call in the web UI.
Changes:
- orchestrator_engine.py: add OrchestrateCheckpoint dataclass, extract
loop into _run_from_contents(), add resume() function
- openai_orchestrator.py: same treatment — _run_from_messages(), resume()
- routers/orchestrator.py: POST /{job_id}/confirm and /deny endpoints,
separate _checkpoints store, _resume_job() + _finalize_job() helpers,
"awaiting_confirmation" job status with pending_confirmation payload
- auth_utils.py: get_tool_policy() and save_tool_policy() helpers reading
home/{user}/tool_policy.json (allow/deny lists)
- routers/orchestrator.py: loads tool_policy per user and passes
confirm_allow/confirm_deny to both engines
- app.js: poll loop handles awaiting_confirmation — shows Confirm/Deny
buttons inline, resumes polling after user action
- settings.html + settings.py: Tool Permissions section with allow/deny
textareas, POST /settings/tool-policy route
- style.css: .confirm-gate, .confirm-btn, .deny-btn styles
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,10 +15,10 @@ import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from auth_utils import get_user_gemini_key, get_user_role
|
||||
from auth_utils import get_user_gemini_key, get_user_role, get_tool_policy
|
||||
from config import settings
|
||||
from context_loader import load_context
|
||||
from persona import set_context, validate as validate_persona
|
||||
@@ -31,12 +31,16 @@ router = APIRouter(prefix="/orchestrate", tags=["orchestrator"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory job store
|
||||
# Jobs are keyed by UUID. For this phase, memory is fine — jobs are short-lived.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_jobs: dict[str, dict] = {}
|
||||
_jobs_lock = asyncio.Lock()
|
||||
|
||||
# Checkpoints are stored separately — they hold Python objects (types.Content, etc.)
|
||||
# that can't be included in the JSON-serializable job dict.
|
||||
_checkpoints: dict[str, orchestrator_engine.OrchestrateCheckpoint] = {}
|
||||
_checkpoints_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
@@ -57,7 +61,7 @@ class OrchestrateRequest(BaseModel):
|
||||
|
||||
class OrchestrateResponse(BaseModel):
|
||||
job_id: str
|
||||
status: str # "queued" | "running" | "complete" | "error"
|
||||
status: str # "queued" | "running" | "complete" | "error" | "awaiting_confirmation"
|
||||
|
||||
|
||||
class JobStatusResponse(BaseModel):
|
||||
@@ -72,6 +76,7 @@ class JobStatusResponse(BaseModel):
|
||||
backend: str | None = None
|
||||
gemini_summary: str | None = None
|
||||
error: str | None = None
|
||||
pending_confirmation: dict | None = None # {tools: [{name, args}], message: str}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -85,7 +90,6 @@ async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse:
|
||||
user, persona = validate_persona(req.user, req.persona)
|
||||
set_context(user, persona)
|
||||
except ValueError as e:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
@@ -97,17 +101,19 @@ async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse:
|
||||
"task": req.task,
|
||||
"created_at": now,
|
||||
"completed_at": None,
|
||||
"session_id": None,
|
||||
"response": None,
|
||||
"tool_calls": None,
|
||||
"backend": None,
|
||||
"gemini_summary": None,
|
||||
"error": None,
|
||||
"pending_confirmation": None,
|
||||
"_user": user,
|
||||
}
|
||||
|
||||
async with _jobs_lock:
|
||||
_jobs[job_id] = job
|
||||
|
||||
# Run in background — caller polls GET /orchestrate/{job_id}
|
||||
asyncio.create_task(_run_job(job_id, req, user))
|
||||
logger.info("Orchestrator job queued: %s — %.80s", job_id, req.task)
|
||||
return OrchestrateResponse(job_id=job_id, status="queued")
|
||||
@@ -120,10 +126,9 @@ async def job_status(job_id: str) -> JobStatusResponse:
|
||||
job = _jobs.get(job_id)
|
||||
|
||||
if job is None:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
||||
|
||||
return JobStatusResponse(**job)
|
||||
return JobStatusResponse(**{k: v for k, v in job.items() if not k.startswith("_")})
|
||||
|
||||
|
||||
@router.get("", response_model=list[JobStatusResponse])
|
||||
@@ -131,11 +136,55 @@ async def list_jobs() -> list[JobStatusResponse]:
|
||||
"""List all jobs (most recent first). Useful for debugging."""
|
||||
async with _jobs_lock:
|
||||
jobs = sorted(_jobs.values(), key=lambda j: j["created_at"], reverse=True)
|
||||
return [JobStatusResponse(**j) for j in jobs]
|
||||
return [JobStatusResponse(**{k: v for k, v in j.items() if not k.startswith("_")}) for j in jobs]
|
||||
|
||||
|
||||
@router.post("/{job_id}/confirm", response_model=OrchestrateResponse)
async def confirm_job(job_id: str) -> OrchestrateResponse:
    """Confirm a pending tool call — the blocked tool will execute and the job continues.

    Takes the stored checkpoint for ``job_id``, marks the job ``"running"``
    again, clears its ``pending_confirmation`` payload and schedules
    ``_resume_job`` in the background with ``confirmed=True``.

    Raises:
        HTTPException: 404 when no checkpoint is stored for the job; 409 when
            the job is missing or its status is not ``"awaiting_confirmation"``.
    """
    async with _checkpoints_lock:
        checkpoint = _checkpoints.pop(job_id, None)

    if checkpoint is None:
        raise HTTPException(status_code=404, detail="No pending confirmation for this job")

    async with _jobs_lock:
        job = _jobs.get(job_id)
        awaiting = bool(job) and job["status"] == "awaiting_confirmation"
        if awaiting:
            job["status"] = "running"
            job["pending_confirmation"] = None
            # NOTE(review): "scott" is a hard-coded fallback; jobs created by
            # orchestrate() always carry "_user" — confirm this default is needed.
            user = job.get("_user", "scott")

    if not awaiting:
        # The runners store the checkpoint *before* they flip the job status to
        # "awaiting_confirmation". A request landing in that window must not
        # consume the checkpoint, otherwise the job could never be resumed.
        # Put it back (without clobbering a newer one) so a retry can succeed.
        if job:
            async with _checkpoints_lock:
                _checkpoints.setdefault(job_id, checkpoint)
        raise HTTPException(status_code=409, detail="Job is not awaiting confirmation")

    asyncio.create_task(_resume_job(job_id, checkpoint, confirmed=True, user=user))
    logger.info("Orchestrator job %s confirmed — resuming", job_id)
    return OrchestrateResponse(job_id=job_id, status="running")
|
||||
|
||||
|
||||
@router.post("/{job_id}/deny", response_model=OrchestrateResponse)
async def deny_job(job_id: str) -> OrchestrateResponse:
    """Deny a pending tool call — the tool is skipped and the job produces a final response.

    Takes the stored checkpoint for ``job_id``, marks the job ``"running"``
    again, clears its ``pending_confirmation`` payload and schedules
    ``_resume_job`` in the background with ``confirmed=False``.

    Raises:
        HTTPException: 404 when no checkpoint is stored for the job; 409 when
            the job is missing or its status is not ``"awaiting_confirmation"``.
    """
    async with _checkpoints_lock:
        checkpoint = _checkpoints.pop(job_id, None)

    if checkpoint is None:
        raise HTTPException(status_code=404, detail="No pending confirmation for this job")

    async with _jobs_lock:
        job = _jobs.get(job_id)
        awaiting = bool(job) and job["status"] == "awaiting_confirmation"
        if awaiting:
            job["status"] = "running"
            job["pending_confirmation"] = None
            # NOTE(review): "scott" is a hard-coded fallback; jobs created by
            # orchestrate() always carry "_user" — confirm this default is needed.
            user = job.get("_user", "scott")

    if not awaiting:
        # The runners store the checkpoint *before* they flip the job status to
        # "awaiting_confirmation". A request landing in that window must not
        # consume the checkpoint, otherwise the job could never be resumed.
        # Put it back (without clobbering a newer one) so a retry can succeed.
        if job:
            async with _checkpoints_lock:
                _checkpoints.setdefault(job_id, checkpoint)
        raise HTTPException(status_code=409, detail="Job is not awaiting confirmation")

    asyncio.create_task(_resume_job(job_id, checkpoint, confirmed=False, user=user))
    logger.info("Orchestrator job %s denied — resuming with skip", job_id)
    return OrchestrateResponse(job_id=job_id, status="running")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Background runner
|
||||
# Background runners
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
@@ -146,7 +195,6 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
try:
|
||||
from session_store import load as load_session, save as save_session, generate_session_id
|
||||
|
||||
# Load Inara's system prompt (same as the chat router does)
|
||||
tier = req.tier or settings.default_tier
|
||||
system_prompt = load_context(
|
||||
tier,
|
||||
@@ -155,16 +203,17 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
include_short=req.include_short,
|
||||
)
|
||||
|
||||
# Load session history if a session_id was provided
|
||||
session_id = req.session_id or generate_session_id()
|
||||
history = load_session(session_id)
|
||||
session_messages = history or None
|
||||
|
||||
# Choose engine based on the orchestrator role in the model registry
|
||||
orch_model = model_registry.get_model_for_role(user, "orchestrator")
|
||||
|
||||
user_role = get_user_role(user)
|
||||
|
||||
policy = get_tool_policy(user)
|
||||
confirm_allow = set(policy.get("allow", []))
|
||||
confirm_deny = set(policy.get("deny", []))
|
||||
|
||||
if orch_model and orch_model.get("type") == "local_openai":
|
||||
result = await openai_orchestrator.run(
|
||||
task=req.task,
|
||||
@@ -173,10 +222,10 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
model_cfg=orch_model,
|
||||
respond_with_final=req.respond_with_claude,
|
||||
user_role=user_role,
|
||||
confirm_allow=confirm_allow,
|
||||
confirm_deny=confirm_deny,
|
||||
)
|
||||
else:
|
||||
# Use the API key embedded in the resolved model config (V2 registry with
|
||||
# account_id), then fall back to the per-user key from auth.json, then .env.
|
||||
gemini_key = (
|
||||
(orch_model.get("api_key") if orch_model else None)
|
||||
or get_user_gemini_key(user)
|
||||
@@ -190,28 +239,31 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
model_name=orch_model.get("model_name") if orch_model else None,
|
||||
response_role=req.chat_role,
|
||||
user_role=user_role,
|
||||
confirm_allow=confirm_allow,
|
||||
confirm_deny=confirm_deny,
|
||||
)
|
||||
|
||||
# Save the turn to the session store so it survives a page refresh
|
||||
history.append({"role": "user", "content": req.task})
|
||||
history.append({"role": "assistant", "content": result.response})
|
||||
save_session(session_id, history)
|
||||
if result.checkpoint:
|
||||
async with _checkpoints_lock:
|
||||
_checkpoints[job_id] = result.checkpoint
|
||||
async with _jobs_lock:
|
||||
_jobs[job_id].update({
|
||||
"status": "awaiting_confirmation",
|
||||
"response": result.response,
|
||||
"tool_calls": result.tool_calls,
|
||||
"backend": result.backend,
|
||||
"gemini_summary": result.gemini_summary,
|
||||
"session_id": session_id,
|
||||
"pending_confirmation": {
|
||||
"tools": result.checkpoint.pending_tools,
|
||||
"message": result.response,
|
||||
},
|
||||
})
|
||||
logger.info("Orchestrator job %s awaiting confirmation — %d tool(s) blocked",
|
||||
job_id, len(result.checkpoint.pending_tools))
|
||||
return
|
||||
|
||||
from session_logger import log_turn
|
||||
log_turn(session_id, req.task, result.response)
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
async with _jobs_lock:
|
||||
_jobs[job_id].update({
|
||||
"status": "complete",
|
||||
"completed_at": now,
|
||||
"session_id": session_id,
|
||||
"response": result.response,
|
||||
"tool_calls": result.tool_calls,
|
||||
"backend": result.backend,
|
||||
"gemini_summary": result.gemini_summary,
|
||||
})
|
||||
logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls))
|
||||
await _finalize_job(job_id, result, session_id, req.task, history)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Orchestrator job failed: %s", job_id)
|
||||
@@ -222,3 +274,87 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None:
|
||||
"completed_at": now,
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
|
||||
async def _resume_job(
    job_id: str,
    checkpoint: orchestrator_engine.OrchestrateCheckpoint,
    confirmed: bool,
    user: str,
) -> None:
    """Continue a paused job once the user has answered a confirmation gate.

    The checkpoint's ``engine`` field selects which orchestrator resumes the
    run. If the resumed run stops at another gate, the new checkpoint is
    stored and the job goes back to ``"awaiting_confirmation"``; otherwise the
    job is handed to ``_finalize_job``. Any exception marks the job as
    ``"error"``.

    Args:
        job_id: Key of the job in the in-memory job store.
        checkpoint: Saved engine state to resume from.
        confirmed: True to run the blocked tool(s), False to skip them.
        user: Owner of the job (not read by this function).
    """
    try:
        # Pick the engine that produced the checkpoint.
        if checkpoint.engine == "gemini":
            engine = orchestrator_engine
        else:
            engine = openai_orchestrator
        result = await engine.resume(checkpoint, confirmed)

        if result.checkpoint:
            # Chained gate: park the job again until the next confirm/deny.
            async with _checkpoints_lock:
                _checkpoints[job_id] = result.checkpoint
            async with _jobs_lock:
                job = _jobs[job_id]
                job.update({
                    "status": "awaiting_confirmation",
                    "response": result.response,
                    "tool_calls": result.tool_calls,
                    "backend": result.backend,
                    "gemini_summary": result.gemini_summary,
                    "pending_confirmation": {
                        "tools": result.checkpoint.pending_tools,
                        "message": result.response,
                    },
                })
            logger.info("Orchestrator job %s awaiting another confirmation", job_id)
            return

        # No further gate: read what finalisation needs from the job record.
        async with _jobs_lock:
            job = _jobs[job_id]
            session_id = job.get("session_id") or ""
            task = job.get("task", "")

        from session_store import load as load_session

        if session_id:
            history = load_session(session_id)
        else:
            history = []
        await _finalize_job(job_id, result, session_id, task, history)

    except Exception as e:
        logger.exception("Orchestrator resume failed: %s", job_id)
        failed_at = datetime.now(timezone.utc).isoformat()
        async with _jobs_lock:
            _jobs[job_id].update({
                "status": "error",
                "completed_at": failed_at,
                "error": str(e),
            })
|
||||
|
||||
|
||||
async def _finalize_job(
    job_id: str,
    result: orchestrator_engine.OrchestratorResult,
    session_id: str,
    task: str,
    history: list,
) -> None:
    """Persist the finished turn and flag the job as complete.

    Appends the user task and the assistant response to ``history``, saves it
    under ``session_id`` (a new id is generated when the given one is empty),
    logs the turn, then copies the result fields onto the job record with
    status ``"complete"``.

    Args:
        job_id: Key of the job in the in-memory job store.
        result: Final orchestrator result for the job.
        session_id: Session to save under; may be empty.
        task: The user's original task text.
        history: Prior session messages; mutated in place.
    """
    from session_store import save as save_session, generate_session_id
    from session_logger import log_turn

    session_id = session_id or generate_session_id()

    history.extend([
        {"role": "user", "content": task},
        {"role": "assistant", "content": result.response},
    ])
    save_session(session_id, history)
    log_turn(session_id, task, result.response)

    finished_at = datetime.now(timezone.utc).isoformat()
    completion = {
        "status": "complete",
        "completed_at": finished_at,
        "session_id": session_id,
        "response": result.response,
        "tool_calls": result.tool_calls,
        "backend": result.backend,
        "gemini_summary": result.gemini_summary,
    }
    async with _jobs_lock:
        _jobs[job_id].update(completion)
    logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls))
|
||||
|
||||
@@ -18,7 +18,8 @@ import jwt
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
|
||||
from auth_utils import COOKIE_NAME, decode_token, check_credentials, set_password, _read_auth, _write_auth, get_user_channels
|
||||
from auth_utils import COOKIE_NAME, decode_token, check_credentials, set_password, _read_auth, _write_auth, get_user_channels, get_tool_policy, save_tool_policy
|
||||
from tools import CONFIRM_REQUIRED
|
||||
from persona import list_user_personas
|
||||
from config import settings as app_settings
|
||||
|
||||
@@ -84,6 +85,15 @@ def _settings_page(username: str, personas: list[str], back_persona: str = "", s
|
||||
html = html.replace("{{ nc_notify_room }}", nc_room)
|
||||
html = html.replace("{{ gc_webhook }}", gc_webhook)
|
||||
|
||||
# Tool permission policy
|
||||
policy = get_tool_policy(username)
|
||||
tool_allow_text = _html.escape("\n".join(policy.get("allow", [])))
|
||||
tool_deny_text = _html.escape("\n".join(policy.get("deny", [])))
|
||||
confirm_tools_list = _html.escape(", ".join(sorted(CONFIRM_REQUIRED)))
|
||||
html = html.replace("{{ tool_allow }}", tool_allow_text)
|
||||
html = html.replace("{{ tool_deny }}", tool_deny_text)
|
||||
html = html.replace("{{ confirm_required_tools }}", confirm_tools_list)
|
||||
|
||||
persona_items = "\n".join(
|
||||
f'''<li>
|
||||
<a href="/{username}/{p}" class="persona-link">{p}</a>
|
||||
@@ -302,6 +312,27 @@ async def save_notifications(
|
||||
success="Notification settings saved."))
|
||||
|
||||
|
||||
@router.post("/settings/tool-policy", include_in_schema=False)
async def save_tool_policy_route(
    request: Request,
    allow_list: str = Form(""),
    deny_list: str = Form(""),
):
    """Store the signed-in user's tool allow/deny policy from the settings form.

    Each form field holds one tool name per line; surrounding whitespace is
    trimmed and blank lines are dropped. Requests without a session user are
    redirected to ``/login``. On success the settings page is re-rendered
    with a confirmation message.
    """
    username = _get_session_user(request)
    if not username:
        return RedirectResponse("/login", status_code=302)

    personas = list_user_personas(username)
    back_persona = _preferred_persona(request, username)

    def _tool_names(raw: str) -> list[str]:
        # One tool per line, trimmed, with empty lines discarded.
        trimmed = (line.strip() for line in raw.splitlines())
        return [name for name in trimmed if name]

    allow_tools = _tool_names(allow_list)
    deny_tools = _tool_names(deny_list)
    save_tool_policy(username, {"allow": allow_tools, "deny": deny_tools})
    logger.info("tool policy updated for %s (allow=%d deny=%d)", username, len(allow_tools), len(deny_tools))

    page = _settings_page(username, personas, back_persona,
                          success="Tool permission policy saved.")
    return HTMLResponse(page)
|
||||
|
||||
|
||||
@router.post("/settings/email-allowlist", include_in_schema=False)
|
||||
async def save_email_allowlist(
|
||||
request: Request,
|
||||
|
||||
Reference in New Issue
Block a user