feat: context budget enforcement + compaction in OpenAI orchestrator

Protects all models in the Primary/Backup chain regardless of context window:
- _context_budget(): 75% of model_cfg["context_k"] * 1000 (default 32k if unset)
- _estimate_tokens(): char count / 4 + 3k overhead for tool schemas
- _compact_messages(): truncates old tool results to 400 chars, keeps last 6
  intact (~2 recent rounds), logs chars saved per compaction pass
- Compaction runs before every API call; log line now shows estimated token count
- Malformed tool call args logged with model/args detail instead of silent {}
- finish_reason check accepts "stop" and None alongside "tool_calls" (some
  models return wrong reason even when tool_calls are present)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Scott Idem
2026-05-05 22:01:54 -04:00
parent 7d221863dc
commit a75546485b

View File

@@ -190,6 +190,66 @@ async def resume(checkpoint: OrchestrateCheckpoint, confirmed: bool) -> Orchestr
)
_CHARS_PER_TOKEN = 4
# Fixed token overhead budget for sending 40 tool schemas per call
_TOOL_SCHEMA_OVERHEAD = 3_000
# Chars to keep per truncated old tool result
_TRUNC_RESULT_CHARS = 400
# Always keep the last N tool-result messages uncompacted
_KEEP_RECENT_TOOL_MSGS = 6 # ~2 rounds of 3 tools each
def _estimate_tokens(messages: list[dict]) -> int:
total = sum(len(json.dumps(m)) for m in messages)
return total // _CHARS_PER_TOKEN + _TOOL_SCHEMA_OVERHEAD
def _compact_messages(messages: list[dict], budget_tokens: int) -> list[dict]:
"""
Truncate old tool result content when approaching the context budget.
Strategy: keep system message, recent assistant/tool rounds, and the
original user task intact. Truncate content of old tool results in the
middle of the conversation — the model only needs recent results to reason.
"""
if _estimate_tokens(messages) <= budget_tokens:
return messages
tool_indices = [i for i, m in enumerate(messages) if m.get("role") == "tool"]
n_to_compact = max(0, len(tool_indices) - _KEEP_RECENT_TOOL_MSGS)
if n_to_compact == 0:
return messages # nothing old enough to trim
compact_set = set(tool_indices[:n_to_compact])
result = []
chars_saved = 0
for i, msg in enumerate(messages):
if i in compact_set:
content = msg.get("content", "")
if len(content) > _TRUNC_RESULT_CHARS:
msg = dict(msg)
saved = len(content) - _TRUNC_RESULT_CHARS
chars_saved += saved
msg["content"] = (
content[:_TRUNC_RESULT_CHARS]
+ f" …[{len(content) - _TRUNC_RESULT_CHARS} chars omitted]"
)
result.append(msg)
new_est = _estimate_tokens(result)
logger.info(
"context compaction: saved ~%d tokens (%d chars), now ~%d / %d tokens",
chars_saved // _CHARS_PER_TOKEN, chars_saved, new_est, budget_tokens,
)
return result
def _context_budget(model_cfg: dict | None) -> int:
"""Return the usable token budget (75% of context window, min 16k, default 32k)."""
context_k = (model_cfg or {}).get("context_k") or 32
return max(16_000, int(context_k * 1000 * 0.75))
async def _run_from_messages(
client,
messages: list[dict],
@@ -211,10 +271,13 @@ async def _run_from_messages(
Returns (final_response, checkpoint) — checkpoint is set if confirmation is needed.
"""
final_response = ""
budget = _context_budget(model_cfg)
for round_num in range(starting_round, settings.orchestrator_max_rounds):
logger.info("OpenAI orchestrator round %d / %d model=%s",
round_num + 1, settings.orchestrator_max_rounds, model_name)
messages = _compact_messages(messages, budget)
est = _estimate_tokens(messages)
logger.info("OpenAI orchestrator round %d / %d model=%s ~%d tokens",
round_num + 1, settings.orchestrator_max_rounds, model_name, est)
response = await client.chat.completions.create(
model=model_name,
@@ -240,7 +303,8 @@ async def _run_from_messages(
]
messages.append(assistant_msg)
if choice.finish_reason == "tool_calls" and msg.tool_calls:
# Some models set finish_reason="stop" even when tool_calls are present
if msg.tool_calls and (choice.finish_reason in ("tool_calls", "stop", None)):
# Snapshot state before tool responses for potential checkpoint
pre_fn_state = list(messages)
@@ -249,10 +313,14 @@ async def _run_from_messages(
for tc in msg.tool_calls:
name = tc.function.name
raw_args = tc.function.arguments or "{}"
try:
args_parsed = json.loads(tc.function.arguments)
except json.JSONDecodeError:
args_parsed = {"raw": tc.function.arguments}
args_parsed = json.loads(raw_args)
if not isinstance(args_parsed, dict):
raise ValueError("args must be a JSON object")
except (json.JSONDecodeError, ValueError) as e:
logger.warning("Malformed tool args for %s: %s — args: %.200s", name, e, raw_args)
args_parsed = {}
if name in effective_confirm:
pending_tools.append({"name": name, "args": args_parsed, "tool_call_id": tc.id})