diff --git a/cortex/auth_utils.py b/cortex/auth_utils.py index e78b054..4d5dbb2 100644 --- a/cortex/auth_utils.py +++ b/cortex/auth_utils.py @@ -224,3 +224,22 @@ def get_user_channels(username: str) -> dict: return json.loads(path.read_text()) except Exception: return {} + + +def get_tool_policy(username: str) -> dict: + """Return the parsed tool_policy.json for a user. + + Keys: + allow — tools in CONFIRM_REQUIRED that this user has pre-approved (skip gate) + deny — tools always blocked for this user regardless of global CONFIRM_REQUIRED + """ + path = settings.home_root() / username / "tool_policy.json" + try: + return json.loads(path.read_text()) + except Exception: + return {} + + +def save_tool_policy(username: str, data: dict) -> None: + path = settings.home_root() / username / "tool_policy.json" + path.write_text(json.dumps(data, indent=2) + "\n") diff --git a/cortex/openai_orchestrator.py b/cortex/openai_orchestrator.py index d8caa38..0698c37 100644 --- a/cortex/openai_orchestrator.py +++ b/cortex/openai_orchestrator.py @@ -24,8 +24,8 @@ import logging from openai import AsyncOpenAI from config import settings -from orchestrator_engine import OrchestratorResult -from tools import OPENAI_TOOL_SCHEMAS, call_tool, get_openai_tools_for_role, CONFIRM_REQUIRED +from orchestrator_engine import OrchestrateCheckpoint, OrchestratorResult +from tools import OPENAI_TOOL_SCHEMAS, call_tool, get_openai_tools_for_role, get_tools_for_role, CONFIRM_REQUIRED logger = logging.getLogger(__name__) @@ -45,6 +45,8 @@ async def run( model_cfg: dict | None = None, respond_with_final: bool = True, user_role: str = "user", + confirm_allow: set[str] | None = None, + confirm_deny: set[str] | None = None, ) -> OrchestratorResult: """ Run a tool-enabled task using an OpenAI-compatible API. @@ -56,36 +58,22 @@ async def run( model_cfg: Resolved model config from model_registry (local_openai type) respond_with_final: If False, return just the tool-loop summary without a full persona-voiced response (faster; for cron/background) + confirm_allow: Tools to bypass the confirmation gate for this user + confirm_deny: Tools to always block for this user Returns: - OrchestratorResult — same shape as the Gemini engine for drop-in compatibility + OrchestratorResult — if checkpoint is set, the job is awaiting confirmation """ if not model_cfg: raise RuntimeError("model_cfg is required for the OpenAI orchestrator") - api_url = model_cfg.get("api_url", "") - api_key = model_cfg.get("api_key", "") or "none" - model_name = model_cfg.get("model_name", "") - host_type = model_cfg.get("host_type", "openwebui") + _confirm_allow = frozenset(confirm_allow or ()) + _confirm_deny = frozenset(confirm_deny or ()) + effective_confirm = (CONFIRM_REQUIRED - set(_confirm_allow)) | set(_confirm_deny) - if not api_url or not model_name: - raise RuntimeError( - f"model_cfg missing api_url or model_name: {model_cfg.get('label', model_cfg)}" - ) + client, model_name, active_tools = _build_client(model_cfg) - # Open WebUI's OpenAI-compatible endpoint lives at /api/chat/completions, - # so the SDK base_url needs the /api prefix; standard OpenAI-layout hosts don't. - base_url = api_url.rstrip("/") - if host_type == "openwebui": - base_url = base_url + "/api" - - client = AsyncOpenAI(base_url=base_url, api_key=api_key) - - # System prompt: persona context + brief tool instruction sys_content = (system_prompt or "") + _TOOL_INSTRUCTION - - # Build messages: [system, ...recent_session, current_task] - # Strip non-standard metadata fields (backend, host, etc.) before sending. messages: list[dict] = [{"role": "system", "content": sys_content}] if session_messages: messages.extend( @@ -94,13 +82,132 @@ async def run( ) messages.append({"role": "user", "content": task}) - active_tools = get_openai_tools_for_role(user_role) - active_callables: dict | None = None # resolved lazily below - tool_call_log: list[dict] = [] + + final_response, checkpoint = await _run_from_messages( + client=client, + messages=messages, + active_tools=active_tools, + tool_call_log=tool_call_log, + effective_confirm=effective_confirm, + model_name=model_name, + task=task, + model_cfg=model_cfg, + respond_with_final=respond_with_final, + user_role=user_role, + confirm_allow=_confirm_allow, + confirm_deny=_confirm_deny, + starting_round=0, + ) + + if checkpoint: + return OrchestratorResult( + response=final_response, + tool_calls=list(tool_call_log), + backend="local", + gemini_summary=final_response, + checkpoint=checkpoint, + ) + + model_label = model_cfg.get("label") or model_name + logger.info("OpenAI orchestrator complete — model=%s tools=%d", model_label, len(tool_call_log)) + return OrchestratorResult( + response=final_response, + tool_calls=tool_call_log, + backend="local", + gemini_summary=final_response, + ) + + +async def resume(checkpoint: OrchestrateCheckpoint, confirmed: bool) -> OrchestratorResult: + """Continue an OpenAI orchestrator job that was paused at a confirmation gate.""" + client, model_name, active_tools = _build_client(checkpoint.model_cfg) + + effective_confirm = (CONFIRM_REQUIRED - set(checkpoint.confirm_allow)) | set(checkpoint.confirm_deny) + + messages = list(checkpoint.pre_fn_state) + tool_call_log = [t for t in checkpoint.tool_call_log if t["result"] != "[awaiting confirmation]"] + + # Build tool responses for this round + for er in checkpoint.executed_results: + messages.append({ + "role": "tool", + "tool_call_id": er.get("tool_call_id", er["name"]), + "content": er["result"], + }) + + for pt in checkpoint.pending_tools: + if confirmed: + _, callables = get_tools_for_role(checkpoint.user_role) + result_str = await _execute_tool_dict(pt["name"], pt["args"], checkpoint.user_role) + logger.info("Confirmed tool %s → %d chars", pt["name"], len(result_str)) + else: + result_str = "Action denied by user." + logger.info("Tool %s denied by user", pt["name"]) + tool_call_log.append({"tool": pt["name"], "args": pt["args"], "result": result_str}) + messages.append({ + "role": "tool", + "tool_call_id": pt.get("tool_call_id", pt["name"]), + "content": result_str, + }) + + final_response, new_checkpoint = await _run_from_messages( + client=client, + messages=messages, + active_tools=active_tools, + tool_call_log=tool_call_log, + effective_confirm=effective_confirm, + model_name=model_name, + task=checkpoint.task, + model_cfg=checkpoint.model_cfg, + respond_with_final=checkpoint.respond_with_final, + user_role=checkpoint.user_role, + confirm_allow=checkpoint.confirm_allow, + confirm_deny=checkpoint.confirm_deny, + starting_round=checkpoint.rounds_used, + ) + + if new_checkpoint: + return OrchestratorResult( + response=final_response, + tool_calls=list(tool_call_log), + backend="local", + gemini_summary=final_response, + checkpoint=new_checkpoint, + ) + + model_label = (checkpoint.model_cfg or {}).get("label") or model_name + logger.info("OpenAI orchestrator resumed — model=%s tools=%d", model_label, len(tool_call_log)) + return OrchestratorResult( + response=final_response, + tool_calls=tool_call_log, + backend="local", + gemini_summary=final_response, + ) + + +async def _run_from_messages( + client, + messages: list[dict], + active_tools: list, + tool_call_log: list[dict], + effective_confirm: set[str], + model_name: str, + task: str, + model_cfg: dict | None, + respond_with_final: bool, + user_role: str, + confirm_allow: frozenset, + confirm_deny: frozenset, + starting_round: int = 0, +) -> tuple[str, OrchestrateCheckpoint | None]: + """ + Run the OpenAI ReAct loop from the current messages state. + Returns (final_response, checkpoint) — checkpoint is set if confirmation is needed. + """ final_response = "" - for round_num in range(settings.orchestrator_max_rounds): + for round_num in range(starting_round, settings.orchestrator_max_rounds): logger.info("OpenAI orchestrator round %d / %d model=%s", round_num + 1, settings.orchestrator_max_rounds, model_name) @@ -112,29 +219,28 @@ async def run( ) choice = response.choices[0] - msg = choice.message + msg = choice.message - # Append the assistant turn (MUST include tool_calls if present so the - # next request is valid — OpenAI requires the full history to be consistent) assistant_msg: dict = {"role": "assistant"} if msg.content: assistant_msg["content"] = msg.content if msg.tool_calls: assistant_msg["tool_calls"] = [ { - "id": tc.id, + "id": tc.id, "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, + "function": {"name": tc.function.name, "arguments": tc.function.arguments}, } for tc in msg.tool_calls ] messages.append(assistant_msg) if choice.finish_reason == "tool_calls" and msg.tool_calls: - confirm_requested = False + # Snapshot state before tool responses for potential checkpoint + pre_fn_state = list(messages) + + pending_tools: list[dict] = [] + executed_results: list[dict] = [] for tc in msg.tool_calls: name = tc.function.name @@ -143,34 +249,23 @@ async def run( except json.JSONDecodeError: args_parsed = {"raw": tc.function.arguments} - if name in CONFIRM_REQUIRED: - args_str = json.dumps(args_parsed, indent=2) if args_parsed else "(no arguments)" - result_str = ( - f"⚠️ CONFIRMATION REQUIRED ⚠️\n" - f"Tool: {name}\nArguments:\n{args_str}\n\n" - f"Do NOT call this tool again. Tell the user exactly what you were " - f"about to do, explain the potential impact, and ask them to confirm " - f"by sending a follow-up message before you proceed." - ) - confirm_requested = True + if name in effective_confirm: + pending_tools.append({"name": name, "args": args_parsed, "tool_call_id": tc.id}) logger.info("Tool %s blocked — confirmation required", name) else: result_str = await _execute_tool(name, tc.function.arguments, user_role) logger.info("Tool %s → %d chars", name, len(result_str)) + executed_results.append({"name": name, "args": args_parsed, "result": result_str, "tool_call_id": tc.id}) + tool_call_log.append({"tool": name, "args": args_parsed, "result": result_str}) + messages.append({"role": "tool", "tool_call_id": tc.id, "content": result_str}) - tool_call_log.append({ - "tool": name, - "args": args_parsed, - "result": "[awaiting confirmation]" if name in CONFIRM_REQUIRED else result_str, - }) - messages.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": result_str, - }) + if pending_tools: + # Add placeholder responses + for pt in pending_tools: + placeholder = f"[AWAITING USER CONFIRMATION for {pt['name']}]" + tool_call_log.append({"tool": pt["name"], "args": pt["args"], "result": "[awaiting confirmation]"}) + messages.append({"role": "tool", "tool_call_id": pt["tool_call_id"], "content": placeholder}) - if confirm_requested: - # One more model round to produce the confirmation-request message, then stop. conf_resp = await client.chat.completions.create( model=model_name, messages=messages, @@ -180,10 +275,24 @@ async def run( final_response = conf_resp.choices[0].message.content or ( "This action requires your explicit confirmation before it can proceed." ) - break + + checkpoint = OrchestrateCheckpoint( + engine="openai", + pre_fn_state=pre_fn_state, + executed_results=executed_results, + pending_tools=pending_tools, + tool_call_log=list(tool_call_log), + task=task, + model_cfg=model_cfg, + respond_with_final=respond_with_final, + user_role=user_role, + confirm_allow=confirm_allow, + confirm_deny=confirm_deny, + rounds_used=round_num + 2, + ) + return final_response, checkpoint else: - # finish_reason == "stop" (or no tool_calls) — model is done final_response = msg.content or "" logger.info( "OpenAI orchestrator done after %d round(s). Tools used: %d", @@ -192,30 +301,37 @@ async def run( break else: - # Hit the round limit logger.warning("OpenAI orchestrator hit max rounds (%d)", settings.orchestrator_max_rounds) final_response = ( f"Reached the tool iteration limit ({settings.orchestrator_max_rounds} rounds). " "Here is what was gathered:\n\n" - + "\n\n".join( - f"**{t['tool']}**: {t['result'][:500]}" for t in tool_call_log - ) + + "\n\n".join(f"**{t['tool']}**: {t['result'][:500]}" for t in tool_call_log) ) - model_label = model_cfg.get("label") or model_name - logger.info("OpenAI orchestrator complete — model=%s tools=%d", model_label, len(tool_call_log)) + return final_response, None - return OrchestratorResult( - response=final_response, - tool_calls=tool_call_log, - backend="local", - gemini_summary=final_response, # reused for UI display; same content in single-model mode - ) + +def _build_client(model_cfg: dict | None) -> tuple: + """Build AsyncOpenAI client and return (client, model_name, active_tools).""" + if not model_cfg: + raise RuntimeError("model_cfg is required for the OpenAI orchestrator") + api_url = model_cfg.get("api_url", "") + api_key = model_cfg.get("api_key", "") or "none" + model_name = model_cfg.get("model_name", "") + host_type = model_cfg.get("host_type", "openwebui") + if not api_url or not model_name: + raise RuntimeError( + f"model_cfg missing api_url or model_name: {model_cfg.get('label', model_cfg)}" + ) + base_url = api_url.rstrip("/") + if host_type == "openwebui": + base_url = base_url + "/api" + client = AsyncOpenAI(base_url=base_url, api_key=api_key) + return client, model_name, OPENAI_TOOL_SCHEMAS async def _execute_tool(name: str, arguments_json: str, user_role: str = "user") -> str: """Parse tool arguments and execute with role-filtered callables.""" - from tools import get_tools_for_role _, callables = get_tools_for_role(user_role) try: args = json.loads(arguments_json) @@ -226,3 +342,13 @@ async def _execute_tool(name: str, arguments_json: str, user_role: str = "user") except Exception as e: logger.warning("Tool %s failed: %s", name, e) return f"Tool error: {e}" + + +async def _execute_tool_dict(name: str, args: dict, user_role: str = "user") -> str: + """Execute a tool from a pre-parsed args dict.""" + _, callables = get_tools_for_role(user_role) + try: + return await call_tool(name, args, callables) + except Exception as e: + logger.warning("Tool %s failed: %s", name, e) + return f"Tool error: {e}" diff --git a/cortex/orchestrator_engine.py b/cortex/orchestrator_engine.py index f9afd13..c40466a 100644 --- a/cortex/orchestrator_engine.py +++ b/cortex/orchestrator_engine.py @@ -44,12 +44,39 @@ Keep your summary factual and complete. Include relevant URLs, data, and specifi If no tools are needed, return an empty summary.""" +@dataclass +class OrchestrateCheckpoint: + """Saved execution state for a job paused at a confirmation gate.""" + engine: str # "gemini" | "openai" + pre_fn_state: list # conversation state before function responses + executed_results: list[dict] # tools that already ran this round + pending_tools: list[dict] # [{name, args}] awaiting confirmation + tool_call_log: list[dict] # all tool calls so far + task: str + # Gemini-specific config (unused by openai engine) + system_prompt: str = "" + session_messages: list | None = None + model_name: str | None = None + gemini_api_key: str | None = None + respond_with_claude: bool = True + response_role: str = "chat" + # OpenAI-specific config (unused by gemini engine) + model_cfg: dict | None = None + respond_with_final: bool = True + # Common + user_role: str = "user" + confirm_allow: frozenset = field(default_factory=frozenset) + confirm_deny: frozenset = field(default_factory=frozenset) + rounds_used: int = 0 + + @dataclass class OrchestratorResult: response: str # final user-facing response (from Claude) tool_calls: list[dict] = field(default_factory=list) # [{tool, args, result}] backend: str = "claude" # model that produced the final response gemini_summary: str = "" # what Gemini handed to Claude (debug/display) + checkpoint: OrchestrateCheckpoint | None = None # set when awaiting confirmation async def run( @@ -61,6 +88,8 @@ async def run( model_name: str | None = None, response_role: str = "chat", user_role: str = "user", + confirm_allow: set[str] | None = None, + confirm_deny: set[str] | None = None, ) -> OrchestratorResult: """ Run the full orchestration loop for a task. @@ -72,9 +101,11 @@ async def run( respond_with_claude: If False, return Gemini's summary as the response (useful for background/cron tasks where a polished reply isn't needed) gemini_api_key: Per-user Gemini API key (falls back to GEMINI_API_KEY in .env) + confirm_allow: Tools to bypass the confirmation gate for this user + confirm_deny: Tools to always block for this user Returns: - OrchestratorResult with response, tool call log, backend used, and Gemini summary + OrchestratorResult — if checkpoint is set, the job is awaiting confirmation """ api_key = gemini_api_key or settings.gemini_api_key if not api_key: @@ -85,19 +116,157 @@ async def run( client = genai.Client(api_key=api_key) - # Seed Gemini with the task — include recent session context if available + _confirm_allow = frozenset(confirm_allow or ()) + _confirm_deny = frozenset(confirm_deny or ()) + effective_confirm = (CONFIRM_REQUIRED - set(_confirm_allow)) | set(_confirm_deny) + task_with_context = _build_task_prompt(task, session_messages) contents: list[types.Content] = [ types.Content(role="user", parts=[types.Part(text=task_with_context)]) ] - tool_declarations, tool_callables = get_tools_for_role(user_role) - tool_call_log: list[dict] = [] + + gemini_summary, checkpoint = await _run_from_contents( + client=client, + contents=contents, + tool_declarations=tool_declarations, + tool_callables=tool_callables, + tool_call_log=tool_call_log, + effective_confirm=effective_confirm, + model_name=model_name, + task=task, + system_prompt=system_prompt, + session_messages=session_messages, + respond_with_claude=respond_with_claude, + response_role=response_role, + user_role=user_role, + confirm_allow=_confirm_allow, + confirm_deny=_confirm_deny, + starting_round=0, + gemini_api_key=api_key, + ) + + if checkpoint: + return OrchestratorResult( + response=gemini_summary, + tool_calls=list(tool_call_log), + backend="gemini", + gemini_summary=gemini_summary, + checkpoint=checkpoint, + ) + + return await _claude_handoff( + task=task, + tool_call_log=tool_call_log, + gemini_summary=gemini_summary, + system_prompt=system_prompt, + session_messages=session_messages, + respond_with_claude=respond_with_claude, + response_role=response_role, + ) + + +async def resume(checkpoint: OrchestrateCheckpoint, confirmed: bool) -> OrchestratorResult: + """Continue a job that was paused at a confirmation gate.""" + api_key = checkpoint.gemini_api_key or settings.gemini_api_key + client = genai.Client(api_key=api_key) + tool_declarations, tool_callables = get_tools_for_role(checkpoint.user_role) + + effective_confirm = (CONFIRM_REQUIRED - set(checkpoint.confirm_allow)) | set(checkpoint.confirm_deny) + + # Rebuild from saved state — strip "[awaiting confirmation]" placeholders + contents = list(checkpoint.pre_fn_state) + tool_call_log = [t for t in checkpoint.tool_call_log if t["result"] != "[awaiting confirmation]"] + + # Build function responses for this round + response_parts: list[types.Part] = [] + + for er in checkpoint.executed_results: + response_parts.append(types.Part(function_response=types.FunctionResponse( + name=er["name"], response={"result": er["result"]} + ))) + + for pt in checkpoint.pending_tools: + if confirmed: + result_str = await _execute_tool(pt["name"], pt["args"], tool_callables) + logger.info("Confirmed tool %s → %d chars", pt["name"], len(result_str)) + else: + result_str = "Action denied by user." + logger.info("Tool %s denied by user", pt["name"]) + tool_call_log.append({"tool": pt["name"], "args": pt["args"], "result": result_str}) + response_parts.append(types.Part(function_response=types.FunctionResponse( + name=pt["name"], response={"result": result_str} + ))) + + contents.append(types.Content(role="user", parts=response_parts)) + + gemini_summary, new_checkpoint = await _run_from_contents( + client=client, + contents=contents, + tool_declarations=tool_declarations, + tool_callables=tool_callables, + tool_call_log=tool_call_log, + effective_confirm=effective_confirm, + model_name=checkpoint.model_name, + task=checkpoint.task, + system_prompt=checkpoint.system_prompt, + session_messages=checkpoint.session_messages, + respond_with_claude=checkpoint.respond_with_claude, + response_role=checkpoint.response_role, + user_role=checkpoint.user_role, + confirm_allow=checkpoint.confirm_allow, + confirm_deny=checkpoint.confirm_deny, + starting_round=checkpoint.rounds_used, + gemini_api_key=api_key, + ) + + if new_checkpoint: + return OrchestratorResult( + response=gemini_summary, + tool_calls=list(tool_call_log), + backend="gemini", + gemini_summary=gemini_summary, + checkpoint=new_checkpoint, + ) + + return await _claude_handoff( + task=checkpoint.task, + tool_call_log=tool_call_log, + gemini_summary=gemini_summary, + system_prompt=checkpoint.system_prompt, + session_messages=checkpoint.session_messages, + respond_with_claude=checkpoint.respond_with_claude, + response_role=checkpoint.response_role, + ) + + +async def _run_from_contents( + client, + contents: list, + tool_declarations: list, + tool_callables: dict, + tool_call_log: list[dict], + effective_confirm: set[str], + model_name: str | None, + task: str, + system_prompt: str, + session_messages: list[dict] | None, + respond_with_claude: bool, + response_role: str, + user_role: str, + confirm_allow: frozenset, + confirm_deny: frozenset, + starting_round: int = 0, + gemini_api_key: str | None = None, +) -> tuple[str, OrchestrateCheckpoint | None]: + """ + Run the ReAct loop from the current contents state. + Returns (gemini_summary, checkpoint) — checkpoint is set if confirmation is needed. + """ gemini_summary = "" - # --- ReAct tool loop --- - for round_num in range(settings.orchestrator_max_rounds): + for round_num in range(starting_round, settings.orchestrator_max_rounds): logger.info("Orchestrator round %d for task: %.80s", round_num + 1, task) response = await asyncio.to_thread( @@ -113,67 +282,56 @@ async def run( candidate = response.candidates[0] parts = candidate.content.parts if candidate.content else [] - # Check if Gemini wants to call any tools tool_call_parts = [ p for p in parts if hasattr(p, "function_call") and p.function_call and p.function_call.name ] if not tool_call_parts: - # No more tool calls — extract Gemini's text summary gemini_summary = "".join( p.text for p in parts if hasattr(p, "text") and p.text ).strip() logger.info("Orchestrator done after %d round(s). Tools used: %d", round_num + 1, len(tool_call_log)) - break + return gemini_summary, None - # Add Gemini's response (with function calls) to the conversation contents.append(candidate.content) - # Execute tool calls — check confirmation requirement before calling + # Snapshot state before function responses — used if a checkpoint is needed + pre_fn_state = list(contents) + response_parts: list[types.Part] = [] - confirm_requested = False + pending_tools: list[dict] = [] + executed_results: list[dict] = [] for fc_part in tool_call_parts: fc = fc_part.function_call name = fc.name args = dict(fc.args) - if name in CONFIRM_REQUIRED: - args_str = json.dumps(args, indent=2) if args else "(no arguments)" - result_str = ( - f"⚠️ CONFIRMATION REQUIRED ⚠️\n" - f"Tool: {name}\nArguments:\n{args_str}\n\n" - f"Do NOT call this tool again. Tell the user exactly what you were " - f"about to do, explain the potential impact, and ask them to confirm " - f"by sending a follow-up message before you proceed." - ) - confirm_requested = True + if name in effective_confirm: + pending_tools.append({"name": name, "args": args}) logger.info("Tool %s blocked — confirmation required", name) else: result_str = await _execute_tool(name, args, tool_callables) logger.info("Tool %s → %d chars", name, len(result_str)) + executed_results.append({"name": name, "args": args, "result": result_str}) + tool_call_log.append({"tool": name, "args": args, "result": result_str}) + response_parts.append(types.Part(function_response=types.FunctionResponse( + name=name, response={"result": result_str} + ))) - tool_call_log.append({ - "tool": name, - "args": args, - "result": "[awaiting confirmation]" if name in CONFIRM_REQUIRED else result_str, - }) - response_parts.append( - types.Part( - function_response=types.FunctionResponse( - name=name, - response={"result": result_str}, - ) - ) - ) + if pending_tools: + # Add placeholder responses and get Gemini to produce the confirmation message + for pt in pending_tools: + placeholder = f"[AWAITING USER CONFIRMATION for {pt['name']}]" + response_parts.append(types.Part(function_response=types.FunctionResponse( + name=pt["name"], response={"result": placeholder} + ))) + tool_call_log.append({"tool": pt["name"], "args": pt["args"], "result": "[awaiting confirmation]"}) - contents.append(types.Content(role="user", parts=response_parts)) + contents.append(types.Content(role="user", parts=response_parts)) - if confirm_requested: - # Allow one more Gemini round to produce the confirmation-request message, - # then break — tool is not executed until user confirms in a follow-up. conf_response = await asyncio.to_thread( client.models.generate_content, model=model_name or settings.orchestrator_model, @@ -191,10 +349,30 @@ async def run( gemini_summary = "".join( p.text for p in conf_parts if hasattr(p, "text") and p.text ).strip() or "This action requires your explicit confirmation before it can proceed." - break + + checkpoint = OrchestrateCheckpoint( + engine="gemini", + pre_fn_state=pre_fn_state, + executed_results=executed_results, + pending_tools=pending_tools, + tool_call_log=list(tool_call_log), + task=task, + system_prompt=system_prompt, + session_messages=session_messages, + model_name=model_name, + gemini_api_key=gemini_api_key, + respond_with_claude=respond_with_claude, + response_role=response_role, + user_role=user_role, + confirm_allow=confirm_allow, + confirm_deny=confirm_deny, + rounds_used=round_num + 2, + ) + return gemini_summary, checkpoint + + contents.append(types.Content(role="user", parts=response_parts)) else: - # Hit the round limit — use whatever Gemini produced last logger.warning("Orchestrator hit max rounds (%d)", settings.orchestrator_max_rounds) gemini_summary = ( f"Reached the tool iteration limit ({settings.orchestrator_max_rounds} rounds). " @@ -202,21 +380,28 @@ async def run( + "\n\n".join(f"**{t['tool']}**: {t['result'][:500]}" for t in tool_call_log) ) - # --- Claude handoff --- + return gemini_summary, None + + +async def _claude_handoff( + task: str, + tool_call_log: list[dict], + gemini_summary: str, + system_prompt: str, + session_messages: list[dict] | None, + respond_with_claude: bool, + response_role: str, +) -> OrchestratorResult: if respond_with_claude: claude_prompt = _build_claude_prompt(task, tool_call_log, gemini_summary) - - # Merge with session history so Claude has conversation context messages = list(session_messages or []) messages.append({"role": "user", "content": claude_prompt}) - response_text, backend = await complete( system_prompt=system_prompt, messages=messages, role=response_role, ) else: - # Cron/background tasks: return Gemini's summary directly, no Claude call response_text = gemini_summary or "No information gathered." backend = "gemini" @@ -242,12 +427,11 @@ def _build_task_prompt(task: str, session_messages: list[dict] | None) -> str: if not session_messages: return task - # Include last few turns for context (don't send the full history to keep tokens low) - recent = session_messages[-6:] # last 3 turns + recent = session_messages[-6:] history_lines = [] for msg in recent: label = "User" if msg["role"] == "user" else "Assistant" - history_lines.append(f"{label}: {msg['content'][:300]}") # truncate long messages + history_lines.append(f"{label}: {msg['content'][:300]}") context = "\n".join(history_lines) return f"\n{context}\n\n\nCurrent request: {task}" @@ -265,7 +449,6 @@ def _build_claude_prompt( parts.append("## Research gathered\n") for tc in tool_calls: parts.append(f"### {tc['tool']}({_format_args(tc['args'])})") - # Truncate very long results — Claude gets the gist result = tc["result"] if len(result) > 2000: result = result[:2000] + "\n… [truncated]" diff --git a/cortex/routers/orchestrator.py b/cortex/routers/orchestrator.py index e6a8baa..753c380 100644 --- a/cortex/routers/orchestrator.py +++ b/cortex/routers/orchestrator.py @@ -15,10 +15,10 @@ import logging import uuid from datetime import datetime, timezone -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from auth_utils import get_user_gemini_key, get_user_role +from auth_utils import get_user_gemini_key, get_user_role, get_tool_policy from config import settings from context_loader import load_context from persona import set_context, validate as validate_persona @@ -31,12 +31,16 @@ router = APIRouter(prefix="/orchestrate", tags=["orchestrator"]) # --------------------------------------------------------------------------- # In-memory job store -# Jobs are keyed by UUID. For this phase, memory is fine — jobs are short-lived. # --------------------------------------------------------------------------- _jobs: dict[str, dict] = {} _jobs_lock = asyncio.Lock() +# Checkpoints are stored separately — they hold Python objects (types.Content, etc.) +# that can't be included in the JSON-serializable job dict. +_checkpoints: dict[str, orchestrator_engine.OrchestrateCheckpoint] = {} +_checkpoints_lock = asyncio.Lock() + # --------------------------------------------------------------------------- # Request / response models @@ -57,7 +61,7 @@ class OrchestrateRequest(BaseModel): class OrchestrateResponse(BaseModel): job_id: str - status: str # "queued" | "running" | "complete" | "error" + status: str # "queued" | "running" | "complete" | "error" | "awaiting_confirmation" class JobStatusResponse(BaseModel): @@ -72,6 +76,7 @@ class JobStatusResponse(BaseModel): backend: str | None = None gemini_summary: str | None = None error: str | None = None + pending_confirmation: dict | None = None # {tools: [{name, args}], message: str} # --------------------------------------------------------------------------- @@ -85,7 +90,6 @@ async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse: user, persona = validate_persona(req.user, req.persona) set_context(user, persona) except ValueError as e: - from fastapi import HTTPException raise HTTPException(status_code=400, detail=str(e)) job_id = str(uuid.uuid4()) @@ -97,17 +101,19 @@ async def orchestrate(req: OrchestrateRequest) -> OrchestrateResponse: "task": req.task, "created_at": now, "completed_at": None, + "session_id": None, "response": None, "tool_calls": None, "backend": None, "gemini_summary": None, "error": None, + "pending_confirmation": None, + "_user": user, } async with _jobs_lock: _jobs[job_id] = job - # Run in background — caller polls GET /orchestrate/{job_id} asyncio.create_task(_run_job(job_id, req, user)) logger.info("Orchestrator job queued: %s — %.80s", job_id, req.task) return OrchestrateResponse(job_id=job_id, status="queued") @@ -120,10 +126,9 @@ async def job_status(job_id: str) -> JobStatusResponse: job = _jobs.get(job_id) if job is None: - from fastapi import HTTPException raise HTTPException(status_code=404, detail=f"Job {job_id} not found") - return JobStatusResponse(**job) + return JobStatusResponse(**{k: v for k, v in job.items() if not k.startswith("_")}) @router.get("", response_model=list[JobStatusResponse]) @@ -131,11 +136,55 @@ async def list_jobs() -> list[JobStatusResponse]: """List all jobs (most recent first). Useful for debugging.""" async with _jobs_lock: jobs = sorted(_jobs.values(), key=lambda j: j["created_at"], reverse=True) - return [JobStatusResponse(**j) for j in jobs] + return [JobStatusResponse(**{k: v for k, v in j.items() if not k.startswith("_")}) for j in jobs] + + +@router.post("/{job_id}/confirm", response_model=OrchestrateResponse) +async def confirm_job(job_id: str) -> OrchestrateResponse: + """Confirm a pending tool call — the blocked tool will execute and the job continues.""" + async with _checkpoints_lock: + checkpoint = _checkpoints.pop(job_id, None) + + if checkpoint is None: + raise HTTPException(status_code=404, detail="No pending confirmation for this job") + + async with _jobs_lock: + job = _jobs.get(job_id) + if not job or job["status"] != "awaiting_confirmation": + raise HTTPException(status_code=409, detail="Job is not awaiting confirmation") + _jobs[job_id]["status"] = "running" + _jobs[job_id]["pending_confirmation"] = None + user = job.get("_user", "scott") + + asyncio.create_task(_resume_job(job_id, checkpoint, confirmed=True, user=user)) + logger.info("Orchestrator job %s confirmed — resuming", job_id) + return OrchestrateResponse(job_id=job_id, status="running") + + +@router.post("/{job_id}/deny", response_model=OrchestrateResponse) +async def deny_job(job_id: str) -> OrchestrateResponse: + """Deny a pending tool call — the tool is skipped and the job produces a final response.""" + async with _checkpoints_lock: + checkpoint = _checkpoints.pop(job_id, None) + + if checkpoint is None: + raise HTTPException(status_code=404, detail="No pending confirmation for this job") + + async with _jobs_lock: + job = _jobs.get(job_id) + if not job or job["status"] != "awaiting_confirmation": + raise HTTPException(status_code=409, detail="Job is not awaiting confirmation") + _jobs[job_id]["status"] = "running" + _jobs[job_id]["pending_confirmation"] = None + user = job.get("_user", "scott") + + asyncio.create_task(_resume_job(job_id, checkpoint, confirmed=False, user=user)) + logger.info("Orchestrator job %s denied — resuming with skip", job_id) + return OrchestrateResponse(job_id=job_id, status="running") # --------------------------------------------------------------------------- -# Background runner +# Background runners # --------------------------------------------------------------------------- async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: @@ -146,7 +195,6 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: try: from session_store import load as load_session, save as save_session, generate_session_id - # Load Inara's system prompt (same as the chat router does) tier = req.tier or settings.default_tier system_prompt = load_context( tier, @@ -155,16 +203,17 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: include_short=req.include_short, ) - # Load session history if a session_id was provided session_id = req.session_id or generate_session_id() history = load_session(session_id) session_messages = history or None - # Choose engine based on the orchestrator role in the model registry orch_model = model_registry.get_model_for_role(user, "orchestrator") - user_role = get_user_role(user) + policy = get_tool_policy(user) + confirm_allow = set(policy.get("allow", [])) + confirm_deny = set(policy.get("deny", [])) + if orch_model and orch_model.get("type") == "local_openai": result = await openai_orchestrator.run( task=req.task, @@ -173,10 +222,10 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: model_cfg=orch_model, respond_with_final=req.respond_with_claude, user_role=user_role, + confirm_allow=confirm_allow, + confirm_deny=confirm_deny, ) else: - # Use the API key embedded in the resolved model config (V2 registry with - # account_id), then fall back to the per-user key from auth.json, then .env. gemini_key = ( (orch_model.get("api_key") if orch_model else None) or get_user_gemini_key(user) @@ -190,28 +239,31 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: model_name=orch_model.get("model_name") if orch_model else None, response_role=req.chat_role, user_role=user_role, + confirm_allow=confirm_allow, + confirm_deny=confirm_deny, ) - # Save the turn to the session store so it survives a page refresh - history.append({"role": "user", "content": req.task}) - history.append({"role": "assistant", "content": result.response}) - save_session(session_id, history) + if result.checkpoint: + async with _checkpoints_lock: + _checkpoints[job_id] = result.checkpoint + async with _jobs_lock: + _jobs[job_id].update({ + "status": "awaiting_confirmation", + "response": result.response, + "tool_calls": result.tool_calls, + "backend": result.backend, + "gemini_summary": result.gemini_summary, + "session_id": session_id, + "pending_confirmation": { + "tools": result.checkpoint.pending_tools, + "message": result.response, + }, + }) + logger.info("Orchestrator job %s awaiting confirmation — %d tool(s) blocked", + job_id, len(result.checkpoint.pending_tools)) + return - from session_logger import log_turn - log_turn(session_id, req.task, result.response) - - now = datetime.now(timezone.utc).isoformat() - async with _jobs_lock: - _jobs[job_id].update({ - "status": "complete", - "completed_at": now, - "session_id": session_id, - "response": result.response, - "tool_calls": result.tool_calls, - "backend": result.backend, - "gemini_summary": result.gemini_summary, - }) - logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls)) + await _finalize_job(job_id, result, session_id, req.task, history) except Exception as e: logger.exception("Orchestrator job failed: %s", job_id) @@ -222,3 +274,87 @@ async def _run_job(job_id: str, req: OrchestrateRequest, user: str) -> None: "completed_at": now, "error": str(e), }) + + +async def _resume_job( + job_id: str, + checkpoint: orchestrator_engine.OrchestrateCheckpoint, + confirmed: bool, + user: str, +) -> None: + """Resume a job after the user confirms or denies a pending tool call.""" + try: + if checkpoint.engine == "gemini": + result = await orchestrator_engine.resume(checkpoint, confirmed) + else: + result = await openai_orchestrator.resume(checkpoint, confirmed) + + if result.checkpoint: + # Another confirmation needed (chained gates) + async with _checkpoints_lock: + _checkpoints[job_id] = result.checkpoint + async with _jobs_lock: + _jobs[job_id].update({ + "status": "awaiting_confirmation", + "response": result.response, + "tool_calls": result.tool_calls, + "backend": result.backend, + "gemini_summary": result.gemini_summary, + "pending_confirmation": { + "tools": result.checkpoint.pending_tools, + "message": result.response, + }, + }) + logger.info("Orchestrator job %s awaiting another confirmation", job_id) + return + + async with _jobs_lock: + session_id = _jobs[job_id].get("session_id") or "" + task = _jobs[job_id].get("task", "") + + from session_store import load as load_session + history = load_session(session_id) if session_id else [] + await _finalize_job(job_id, result, session_id, task, history) + + except Exception as e: + logger.exception("Orchestrator resume failed: %s", job_id) + now = datetime.now(timezone.utc).isoformat() + async with _jobs_lock: + _jobs[job_id].update({ + "status": "error", + "completed_at": now, + "error": str(e), + }) + + +async def _finalize_job( + job_id: str, + result: orchestrator_engine.OrchestratorResult, + session_id: str, + task: str, + history: list, +) -> None: + """Save session, log the turn, and mark the job complete.""" + from session_store import save as save_session, generate_session_id + from session_logger import log_turn + + if not session_id: + session_id = generate_session_id() + + history.append({"role": "user", "content": task}) + history.append({"role": "assistant", "content": result.response}) + save_session(session_id, history) + log_turn(session_id, task, result.response) + + now = datetime.now(timezone.utc).isoformat() + async with _jobs_lock: + _jobs[job_id].update({ + "status": "complete", + "completed_at": now, + "session_id": session_id, + "response": result.response, + "tool_calls": result.tool_calls, + "backend": result.backend, + "gemini_summary": result.gemini_summary, + }) + logger.info("Orchestrator job complete: %s (%d tool calls)", job_id, len(result.tool_calls)) diff --git a/cortex/routers/settings.py b/cortex/routers/settings.py index 68cb033..d6f43ab 100644 --- a/cortex/routers/settings.py +++ b/cortex/routers/settings.py @@ -18,7 +18,8 @@ import jwt from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse -from auth_utils import COOKIE_NAME, decode_token, check_credentials, set_password, _read_auth, _write_auth, get_user_channels +from auth_utils import COOKIE_NAME, decode_token, check_credentials, set_password, _read_auth, _write_auth, get_user_channels, get_tool_policy, save_tool_policy +from tools import CONFIRM_REQUIRED from persona import list_user_personas from config import settings as app_settings @@ -84,6 +85,15 @@ def _settings_page(username: str, personas: list[str], back_persona: str = "", s html = html.replace("{{ nc_notify_room }}", nc_room) html = html.replace("{{ gc_webhook }}", gc_webhook) + # Tool permission policy + policy = get_tool_policy(username) + tool_allow_text = _html.escape("\n".join(policy.get("allow", []))) + tool_deny_text = _html.escape("\n".join(policy.get("deny", []))) + confirm_tools_list = _html.escape(", ".join(sorted(CONFIRM_REQUIRED))) + html = html.replace("{{ tool_allow }}", tool_allow_text) + html = html.replace("{{ tool_deny }}", tool_deny_text) + html = html.replace("{{ confirm_required_tools }}", confirm_tools_list) + persona_items = "\n".join( f'''
  • {p} @@ -302,6 +312,27 @@ async def save_notifications( success="Notification settings saved.")) +@router.post("/settings/tool-policy", include_in_schema=False) +async def save_tool_policy_route( + request: Request, + allow_list: str = Form(""), + deny_list: str = Form(""), +): + username = _get_session_user(request) + if not username: + return RedirectResponse("/login", status_code=302) + + personas = list_user_personas(username) + back_persona = _preferred_persona(request, username) + + allow_tools = [ln.strip() for ln in allow_list.splitlines() if ln.strip()] + deny_tools = [ln.strip() for ln in deny_list.splitlines() if ln.strip()] + save_tool_policy(username, {"allow": allow_tools, "deny": deny_tools}) + logger.info("tool policy updated for %s (allow=%d deny=%d)", username, len(allow_tools), len(deny_tools)) + return HTMLResponse(_settings_page(username, personas, back_persona, + success="Tool permission policy saved.")) + + @router.post("/settings/email-allowlist", include_in_schema=False) async def save_email_allowlist( request: Request, diff --git a/cortex/static/app.js b/cortex/static/app.js index f3843e7..724a110 100644 --- a/cortex/static/app.js +++ b/cortex/static/app.js @@ -1192,6 +1192,37 @@ : '⚡ working…'; continue; } + + if (job.status === 'awaiting_confirmation') { + const pc = job.pending_confirmation || {}; + const toolNames = (pc.tools || []).map(t => t.name).join(', '); + thinkingDiv.className = 'message assistant'; + thinkingDiv.innerHTML = `
    +

    ${escapeHtml(pc.message || 'Confirm this action?')}

    +

    Tool${(pc.tools||[]).length !== 1 ? 's' : ''}: ${escapeHtml(toolNames)}

    +
    + + +
    +
    `; + + const confirmed = await new Promise(resolve => { + thinkingDiv.querySelector('.confirm-btn').onclick = () => resolve(true); + thinkingDiv.querySelector('.deny-btn').onclick = () => resolve(false); + }); + + thinkingDiv.className = 'message assistant thinking'; + thinkingDiv.textContent = confirmed ? '⚡ confirmed — continuing…' : '⚡ denied — finishing…'; + + const action = confirmed ? 'confirm' : 'deny'; + const resumeRes = await fetch(`/orchestrate/${job_id}/${action}`, { + method: 'POST', + signal: activeController.signal, + }); + if (!resumeRes.ok) throw new Error(`Resume failed: HTTP ${resumeRes.status}`); + continue; + } + break; } diff --git a/cortex/static/settings.html b/cortex/static/settings.html index 98f8011..3253a90 100644 --- a/cortex/static/settings.html +++ b/cortex/static/settings.html @@ -393,6 +393,35 @@ + +
    +

    Tool Permissions

    +

    + Override the default confirmation gate for orchestrator tools. + Allow list — tools that run without asking for confirmation. + Deny list — tools that are always blocked for your account. + One tool name per line. +

    +

    + Tools requiring confirmation by default: {{ confirm_required_tools }} +

    +
    +
    + + +
    +
    + + +
    + +
    +
    +

    Browser Cache

    diff --git a/cortex/static/style.css b/cortex/static/style.css index a477963..3948bf3 100644 --- a/cortex/static/style.css +++ b/cortex/static/style.css @@ -546,6 +546,25 @@ .message.thinking { color: var(--muted); font-style: italic; } + /* Confirmation gate */ + .confirm-gate { display: flex; flex-direction: column; gap: 0.6rem; } + .confirm-gate p { margin: 0; } + .confirm-tools { font-size: 0.82rem; color: var(--muted); } + .confirm-actions { display: flex; gap: 0.5rem; margin-top: 0.25rem; } + .confirm-btn, .deny-btn { + padding: 0.35rem 0.9rem; + border-radius: 6px; + border: none; + font-size: 0.85rem; + font-weight: 600; + cursor: pointer; + transition: opacity 0.15s; + } + .confirm-btn { background: #16a34a; color: #fff; } + .confirm-btn:hover { opacity: 0.85; } + .deny-btn { background: var(--surface); border: 1px solid var(--border); color: var(--text); } + .deny-btn:hover { border-color: var(--muted); } + /* Copy button */ .message.assistant, .message.user { position: relative; }