Skip to content

Commit e71aaa6

Browse files
authored
Merge pull request #923 from inclusionAI/tool_match
Handle incomplete tool-call history during replay
2 parents 97d32b6 + 595e3c7 commit e71aaa6

2 files changed

Lines changed: 139 additions & 37 deletions

File tree

aworld/agents/llm_agent.py

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -865,17 +865,45 @@ async def async_messages_transform(self,
865865
tool_calls_map = {}
866866
last_tool_calls = []
867867
matched_tool_call_ids = set()
868+
869+
def _is_tool_history(history) -> bool:
870+
if isinstance(history, MemoryMessage):
871+
return isinstance(history, MemoryToolMessage)
872+
return history.metadata.get('role') == 'tool'
873+
874+
def _drop_incomplete_tool_call_turn(reason: str):
875+
nonlocal tool_calls_map, last_tool_calls
876+
if not last_tool_calls:
877+
return
878+
dropped_message = None
879+
if messages and messages[-1].get("role") == "assistant" and messages[-1].get("tool_calls"):
880+
dropped_message = messages.pop()
881+
logger.warning(
882+
"Skip incomplete tool-call turn in memory replay: "
883+
f"reason={reason}, missing_tool_call_ids={last_tool_calls}, "
884+
f"matched_tool_result_ids={list(tool_calls_map.keys())}, "
885+
f"dropped_assistant_message={bool(dropped_message)}, agent={self.id()}"
886+
)
887+
tool_calls_map = {}
888+
last_tool_calls = []
889+
890+
def _append_complete_tool_results():
891+
nonlocal tool_calls_map, last_tool_calls
892+
for tool_call_id in last_tool_calls:
893+
if tool_call_id not in tool_calls_map:
894+
_drop_incomplete_tool_call_turn(f"missing tool result for {tool_call_id}")
895+
return
896+
messages.append(tool_calls_map.get(tool_call_id))
897+
matched_tool_call_ids.add(tool_call_id)
898+
tool_calls_map = {}
899+
last_tool_calls = []
900+
868901
for history in histories:
869902
if len(last_tool_calls) > 0 and len(tool_calls_map) == len(last_tool_calls):
870903
# Maintain the order of tool calls
871-
for tool_call_id in last_tool_calls:
872-
if tool_call_id not in tool_calls_map:
873-
raise AWorldRuntimeException(
874-
f"tool_calls mismatch! {tool_call_id} not found in {tool_calls_map}, last_tool_calls: {last_tool_calls}, messages: {messages}, histories: {histories}")
875-
messages.append(tool_calls_map.get(tool_call_id))
876-
matched_tool_call_ids.add(tool_call_id)
877-
tool_calls_map = {}
878-
last_tool_calls = []
904+
_append_complete_tool_results()
905+
elif last_tool_calls and not _is_tool_history(history):
906+
_drop_incomplete_tool_call_turn("next non-tool message encountered")
879907

880908
if isinstance(history, MemoryMessage):
881909
if isinstance(history, MemoryToolMessage):
@@ -887,9 +915,10 @@ async def async_messages_transform(self,
887915
f"tool_call_id={history.tool_call_id}, agent={self.id()}"
888916
)
889917
else:
890-
raise AWorldRuntimeException(
891-
f"tool_calls mismatch! {history.tool_call_id} not found in {last_tool_calls}, "
892-
f"messages: {messages}, histories: {histories}")
918+
logger.warning(
919+
"Skip orphan tool result in memory replay: "
920+
f"tool_call_id={history.tool_call_id}, agent={self.id()}"
921+
)
893922
else:
894923
messages.append(history.to_openai_message())
895924
if isinstance(history, MemoryAIMessage) and history.tool_calls:
@@ -908,9 +937,10 @@ async def async_messages_transform(self,
908937
f"tool_call_id={tool_call_id}, agent={self.id()}"
909938
)
910939
else:
911-
raise AWorldRuntimeException(
912-
f"tool_calls mismatch! {tool_call_id} not found in {last_tool_calls}, "
913-
f"messages: {messages}, histories: {histories}")
940+
logger.warning(
941+
"Skip orphan tool result in memory replay: "
942+
f"tool_call_id={tool_call_id}, agent={self.id()}"
943+
)
914944
else:
915945
if not self.use_tools_in_prompt and history.metadata.get('tool_calls'):
916946
messages.append({'role': history.metadata['role'], 'content': history.content,
@@ -922,31 +952,13 @@ async def async_messages_transform(self,
922952
"tool_call_id": history.metadata.get("tool_call_id")})
923953
if len(last_tool_calls) > 0 and len(tool_calls_map) == len(last_tool_calls):
924954
# Maintain the order of tool calls
925-
for tool_call_id in last_tool_calls:
926-
if tool_call_id not in tool_calls_map:
927-
raise AWorldRuntimeException(
928-
f"tool_calls mismatch! {tool_call_id} not found in {tool_calls_map}, last_tool_calls: {last_tool_calls}, messages: {messages}, histories: {histories}")
929-
messages.append(tool_calls_map.get(tool_call_id))
930-
matched_tool_call_ids.add(tool_call_id)
931-
tool_calls_map = {}
932-
last_tool_calls = []
955+
_append_complete_tool_results()
933956
elif len(tool_calls_map) > len(last_tool_calls):
934-
raise AWorldRuntimeException(
935-
f"tool_calls mismatch! {len(tool_calls_map)} tool messages > {len(last_tool_calls)} tool calls: "
936-
f"tool_calls_map={tool_calls_map}, last_tool_calls={last_tool_calls}, messages={messages}, histories={histories}")
937-
if len(tool_calls_map) == len(last_tool_calls):
938-
for tool_call_id in last_tool_calls:
939-
if tool_call_id not in tool_calls_map:
940-
raise AWorldRuntimeException(
941-
f"tool_calls mismatch! {tool_call_id} not found in {tool_calls_map}, last_tool_calls: {last_tool_calls}, messages: {messages}, histories: {histories}")
942-
messages.append(tool_calls_map.get(tool_call_id))
943-
matched_tool_call_ids.add(tool_call_id)
944-
tool_calls_map = {}
945-
last_tool_calls = []
957+
_drop_incomplete_tool_call_turn("more tool results than tool calls")
958+
if last_tool_calls and len(tool_calls_map) == len(last_tool_calls):
959+
_append_complete_tool_results()
946960
else:
947-
raise AWorldRuntimeException(
948-
f"tool_calls mismatch! {len(tool_calls_map)} tool messages != {len(last_tool_calls)} tool calls: "
949-
f"tool_calls_map={tool_calls_map}, last_tool_calls={last_tool_calls}, messages={messages}, histories={histories}")
961+
_drop_incomplete_tool_call_turn("end of history reached")
950962

951963
return messages
952964

tests/runners/test_memory_tool_result_compaction.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,96 @@ async def test_llm_message_replay_skips_duplicate_tool_result(monkeypatch):
304304
assert tool_messages[0]["content"] == [{"type": "text", "text": "first cron result"}]
305305

306306

307+
@pytest.mark.asyncio
308+
async def test_llm_message_replay_drops_incomplete_tool_call_turn(monkeypatch):
309+
meta = MessageMetadata(
310+
session_id="session-1",
311+
user_id="user-1",
312+
task_id="task-1",
313+
agent_id="agent-1",
314+
agent_name="Aworld",
315+
)
316+
ai_message = MemoryAIMessage(
317+
content="",
318+
tool_calls=[
319+
ToolCall.from_dict({
320+
"id": "cron__cron_tool:missing",
321+
"function": {"name": "cron__cron_tool", "arguments": "{}"},
322+
})
323+
],
324+
metadata=meta,
325+
)
326+
fake_memory = _FakeMemory()
327+
fake_memory.items = [(ai_message, None)]
328+
monkeypatch.setattr(
329+
"aworld.agents.llm_agent.MemoryFactory",
330+
type("MemoryFactory", (), {"instance": staticmethod(lambda: fake_memory)}),
331+
)
332+
333+
context = _build_context()
334+
agent = LLMAgent(
335+
name="Aworld",
336+
agent_id="agent-1",
337+
conf=AgentConfig(
338+
llm_model_name="test-model",
339+
llm_api_key="test-key",
340+
memory_config=AgentMemoryConfig(history_rounds=10),
341+
),
342+
)
343+
message = Message(headers={"context": context})
344+
345+
messages = await agent.async_messages_transform(
346+
image_urls=[],
347+
observation=Observation(action_result=[ActionResult(content="continue current turn")]),
348+
message=message,
349+
)
350+
351+
assert not any(message.get("tool_calls") for message in messages)
352+
assert not any(message.get("role") == "tool" for message in messages)
353+
354+
355+
@pytest.mark.asyncio
356+
async def test_llm_message_replay_skips_orphan_tool_result(monkeypatch):
357+
meta = MessageMetadata(
358+
session_id="session-1",
359+
user_id="user-1",
360+
task_id="task-1",
361+
agent_id="agent-1",
362+
agent_name="Aworld",
363+
)
364+
orphan_tool = MemoryToolMessage(
365+
content="orphan result",
366+
tool_call_id="missing-call",
367+
metadata=meta,
368+
)
369+
fake_memory = _FakeMemory()
370+
fake_memory.items = [(orphan_tool, None)]
371+
monkeypatch.setattr(
372+
"aworld.agents.llm_agent.MemoryFactory",
373+
type("MemoryFactory", (), {"instance": staticmethod(lambda: fake_memory)}),
374+
)
375+
376+
context = _build_context()
377+
agent = LLMAgent(
378+
name="Aworld",
379+
agent_id="agent-1",
380+
conf=AgentConfig(
381+
llm_model_name="test-model",
382+
llm_api_key="test-key",
383+
memory_config=AgentMemoryConfig(history_rounds=10),
384+
),
385+
)
386+
message = Message(headers={"context": context})
387+
388+
messages = await agent.async_messages_transform(
389+
image_urls=[],
390+
observation=Observation(action_result=[ActionResult(content="continue current turn")]),
391+
message=message,
392+
)
393+
394+
assert not any(message.get("role") == "tool" for message in messages)
395+
396+
307397
@pytest.mark.asyncio
308398
async def test_default_memory_handler_compacts_large_tool_results_by_char_length(monkeypatch):
309399
fake_memory = _FakeMemory()

0 commit comments

Comments
 (0)