feat: context budget enforcement + compaction in OpenAI orchestrator
Protects all models in the Primary/Backup chain regardless of context window:
- _context_budget(): 75% of model_cfg["context_k"] * 1000 (default 32k if unset)
- _estimate_tokens(): char count / 4 + 3k overhead for tool schemas
- _compact_messages(): truncates old tool results to 400 chars, keeps last 6
intact (~2 recent rounds), logs chars saved per compaction pass
- Compaction runs before every API call; log line now shows estimated token count
- Malformed tool call args logged with model/args detail instead of silent {}
- finish_reason check accepts "stop" and None alongside "tool_calls" (some
models return wrong reason even when tool_calls are present)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -190,6 +190,66 @@ async def resume(checkpoint: OrchestrateCheckpoint, confirmed: bool) -> Orchestr
|
||||
)
|
||||
|
||||
|
||||
_CHARS_PER_TOKEN = 4
|
||||
# Fixed token overhead budget for sending 40 tool schemas per call
|
||||
_TOOL_SCHEMA_OVERHEAD = 3_000
|
||||
# Chars to keep per truncated old tool result
|
||||
_TRUNC_RESULT_CHARS = 400
|
||||
# Always keep the last N tool-result messages uncompacted
|
||||
_KEEP_RECENT_TOOL_MSGS = 6 # ~2 rounds of 3 tools each
|
||||
|
||||
|
||||
def _estimate_tokens(messages: list[dict]) -> int:
    """Roughly estimate the prompt size of *messages* in tokens.

    Each message is serialized with ``json.dumps`` and the character counts
    are summed. That total is converted to tokens at ``_CHARS_PER_TOKEN``
    characters per token, and ``_TOOL_SCHEMA_OVERHEAD`` is added on top.
    """
    serialized_chars = 0
    for message in messages:
        serialized_chars += len(json.dumps(message))
    message_tokens = serialized_chars // _CHARS_PER_TOKEN
    return message_tokens + _TOOL_SCHEMA_OVERHEAD
|
||||
|
||||
|
||||
def _compact_messages(messages: list[dict], budget_tokens: int) -> list[dict]:
    """Truncate old tool results when the estimated size exceeds the budget.

    Only messages with role "tool" are candidates, and the most recent
    ``_KEEP_RECENT_TOOL_MSGS`` of them are never touched. Every other message
    (system, user, assistant) passes through unchanged. An older tool result
    is cut to its first ``_TRUNC_RESULT_CHARS`` characters followed by a
    " …[N chars omitted]" marker.

    The same history is compacted again on later rounds, so a result is only
    rewritten when that makes it strictly shorter, and a result that already
    carries the marker is left alone so the original omitted count is not
    overwritten by a bogus one.

    Args:
        messages: Chat messages as dicts; neither the list nor its dicts are
            mutated.
        budget_tokens: Token budget compared against ``_estimate_tokens``.

    Returns:
        ``messages`` itself when it is within budget or nothing could be
        shortened; otherwise a new list in which shortened messages are
        shallow copies and all other entries are the original dict objects.
    """
    if _estimate_tokens(messages) <= budget_tokens:
        return messages

    tool_indices = [i for i, m in enumerate(messages) if m.get("role") == "tool"]
    n_to_compact = max(0, len(tool_indices) - _KEEP_RECENT_TOOL_MSGS)
    if n_to_compact == 0:
        return messages  # nothing old enough to trim

    # Pieces of the marker, used to recognize results truncated on an earlier pass.
    marker_prefix = " …["
    marker_suffix = " chars omitted]"

    compact_set = set(tool_indices[:n_to_compact])
    result = []
    chars_saved = 0
    for i, msg in enumerate(messages):
        if i in compact_set:
            content = msg.get("content")
            # content can be None or a non-string payload; only plain strings
            # can be measured and sliced safely.
            if isinstance(content, str) and len(content) > _TRUNC_RESULT_CHARS:
                tail = content[_TRUNC_RESULT_CHARS:]
                already_truncated = (
                    tail.startswith(marker_prefix)
                    and tail.endswith(marker_suffix)
                    and tail[len(marker_prefix):-len(marker_suffix)].isdigit()
                )
                truncated = (
                    content[:_TRUNC_RESULT_CHARS]
                    + f" …[{len(content) - _TRUNC_RESULT_CHARS} chars omitted]"
                )
                # Skip when the marker would cancel out the saving.
                if not already_truncated and len(truncated) < len(content):
                    chars_saved += len(content) - len(truncated)
                    msg = dict(msg)  # copy so the caller's dict stays intact
                    msg["content"] = truncated
        result.append(msg)

    if chars_saved == 0:
        return messages  # over budget, but nothing left that can be shortened

    new_est = _estimate_tokens(result)
    logger.info(
        "context compaction: saved ~%d tokens (%d chars), now ~%d / %d tokens",
        chars_saved // _CHARS_PER_TOKEN, chars_saved, new_est, budget_tokens,
    )
    return result
|
||||
|
||||
|
||||
def _context_budget(model_cfg: dict | None) -> int:
|
||||
"""Return the usable token budget (75% of context window, min 16k, default 32k)."""
|
||||
context_k = (model_cfg or {}).get("context_k") or 32
|
||||
return max(16_000, int(context_k * 1000 * 0.75))
|
||||
|
||||
|
||||
async def _run_from_messages(
|
||||
client,
|
||||
messages: list[dict],
|
||||
@@ -211,10 +271,13 @@ async def _run_from_messages(
|
||||
Returns (final_response, checkpoint) — checkpoint is set if confirmation is needed.
|
||||
"""
|
||||
final_response = ""
|
||||
budget = _context_budget(model_cfg)
|
||||
|
||||
for round_num in range(starting_round, settings.orchestrator_max_rounds):
|
||||
logger.info("OpenAI orchestrator round %d / %d model=%s",
|
||||
round_num + 1, settings.orchestrator_max_rounds, model_name)
|
||||
messages = _compact_messages(messages, budget)
|
||||
est = _estimate_tokens(messages)
|
||||
logger.info("OpenAI orchestrator round %d / %d model=%s ~%d tokens",
|
||||
round_num + 1, settings.orchestrator_max_rounds, model_name, est)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
@@ -240,7 +303,8 @@ async def _run_from_messages(
|
||||
]
|
||||
messages.append(assistant_msg)
|
||||
|
||||
if choice.finish_reason == "tool_calls" and msg.tool_calls:
|
||||
# Some models set finish_reason="stop" even when tool_calls are present
|
||||
if msg.tool_calls and (choice.finish_reason in ("tool_calls", "stop", None)):
|
||||
# Snapshot state before tool responses for potential checkpoint
|
||||
pre_fn_state = list(messages)
|
||||
|
||||
@@ -249,10 +313,14 @@ async def _run_from_messages(
|
||||
|
||||
for tc in msg.tool_calls:
|
||||
name = tc.function.name
|
||||
raw_args = tc.function.arguments or "{}"
|
||||
try:
|
||||
args_parsed = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
args_parsed = {"raw": tc.function.arguments}
|
||||
args_parsed = json.loads(raw_args)
|
||||
if not isinstance(args_parsed, dict):
|
||||
raise ValueError("args must be a JSON object")
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning("Malformed tool args for %s: %s — args: %.200s", name, e, raw_args)
|
||||
args_parsed = {}
|
||||
|
||||
if name in effective_confirm:
|
||||
pending_tools.append({"name": name, "args": args_parsed, "tool_call_id": tc.id})
|
||||
|
||||
Reference in New Issue
Block a user