diff --git a/cortex/openai_orchestrator.py b/cortex/openai_orchestrator.py index dcf8282..881fc57 100644 --- a/cortex/openai_orchestrator.py +++ b/cortex/openai_orchestrator.py @@ -190,6 +190,66 @@ async def resume(checkpoint: OrchestrateCheckpoint, confirmed: bool) -> Orchestr ) +_CHARS_PER_TOKEN = 4 +# Fixed token overhead budget for sending 40 tool schemas per call +_TOOL_SCHEMA_OVERHEAD = 3_000 +# Chars to keep per truncated old tool result +_TRUNC_RESULT_CHARS = 400 +# Always keep the last N tool-result messages uncompacted +_KEEP_RECENT_TOOL_MSGS = 6 # ~2 rounds of 3 tools each + + +def _estimate_tokens(messages: list[dict]) -> int: + total = sum(len(json.dumps(m)) for m in messages) + return total // _CHARS_PER_TOKEN + _TOOL_SCHEMA_OVERHEAD + + +def _compact_messages(messages: list[dict], budget_tokens: int) -> list[dict]: + """ + Truncate old tool result content when approaching the context budget. + + Strategy: keep system message, recent assistant/tool rounds, and the + original user task intact. Truncate content of old tool results in the + middle of the conversation — the model only needs recent results to reason. + """ + if _estimate_tokens(messages) <= budget_tokens: + return messages + + tool_indices = [i for i, m in enumerate(messages) if m.get("role") == "tool"] + n_to_compact = max(0, len(tool_indices) - _KEEP_RECENT_TOOL_MSGS) + if n_to_compact == 0: + return messages # nothing old enough to trim + + compact_set = set(tool_indices[:n_to_compact]) + result = [] + chars_saved = 0 + for i, msg in enumerate(messages): + if i in compact_set: + content = msg.get("content", "") + if len(content) > _TRUNC_RESULT_CHARS: + msg = dict(msg) + saved = len(content) - _TRUNC_RESULT_CHARS + chars_saved += saved + msg["content"] = ( + content[:_TRUNC_RESULT_CHARS] + + f" …[{len(content) - _TRUNC_RESULT_CHARS} chars omitted]" + ) + result.append(msg) + + new_est = _estimate_tokens(result) + logger.info( + "context compaction: saved ~%d tokens (%d chars), now ~%d / %d tokens", + chars_saved // _CHARS_PER_TOKEN, chars_saved, new_est, budget_tokens, + ) + return result + + +def _context_budget(model_cfg: dict | None) -> int: + """Return the usable token budget (75% of context window, min 16k, default 32k).""" + context_k = (model_cfg or {}).get("context_k") or 32 + return max(16_000, int(context_k * 1000 * 0.75)) + + async def _run_from_messages( client, messages: list[dict], @@ -211,10 +271,13 @@ async def _run_from_messages( Returns (final_response, checkpoint) — checkpoint is set if confirmation is needed. """ final_response = "" + budget = _context_budget(model_cfg) for round_num in range(starting_round, settings.orchestrator_max_rounds): - logger.info("OpenAI orchestrator round %d / %d model=%s", - round_num + 1, settings.orchestrator_max_rounds, model_name) + messages = _compact_messages(messages, budget) + est = _estimate_tokens(messages) + logger.info("OpenAI orchestrator round %d / %d model=%s ~%d tokens", + round_num + 1, settings.orchestrator_max_rounds, model_name, est) response = await client.chat.completions.create( model=model_name, @@ -240,7 +303,8 @@ async def _run_from_messages( ] messages.append(assistant_msg) - if choice.finish_reason == "tool_calls" and msg.tool_calls: + # Some models set finish_reason="stop" even when tool_calls are present + if msg.tool_calls and (choice.finish_reason in ("tool_calls", "stop", None)): # Snapshot state before tool responses for potential checkpoint pre_fn_state = list(messages) @@ -249,10 +313,14 @@ async def _run_from_messages( for tc in msg.tool_calls: name = tc.function.name + raw_args = tc.function.arguments or "{}" try: - args_parsed = json.loads(tc.function.arguments) - except json.JSONDecodeError: - args_parsed = {"raw": tc.function.arguments} + args_parsed = json.loads(raw_args) + if not isinstance(args_parsed, dict): + raise ValueError("args must be a JSON object") + except (json.JSONDecodeError, ValueError) as e: + logger.warning("Malformed tool args for %s: %s — args: %.200s", name, e, raw_args) + args_parsed = {} if name in effective_confirm: pending_tools.append({"name": name, "args": args_parsed, "tool_call_id": tc.id})